diff --git a/.azuredevops/rocm-ci.yml b/.azuredevops/rocm-ci.yml index 4161c2d5a4..b37b8cc27f 100644 --- a/.azuredevops/rocm-ci.yml +++ b/.azuredevops/rocm-ci.yml @@ -14,6 +14,7 @@ trigger: branches: include: - develop + - amd-develop paths: exclude: - .github diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 459315e58b..bd597344ea 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd # Documentation files -docs/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -*.md @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -*.rst @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk -.readthedocs.yaml @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @junliume @illsilin @carlushuang @aosewski @poyenc @geyyer @bartekxk +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000000..56f2acee71 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,10 @@ +We'd love for you to contribute to our source code! + +Some helpful links: + +- [Code of Conduct guidelines](https://www.contributor-covenant.org/version/2/1/code_of_conduct/code_of_conduct.txt) +- [New issue guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/ISSUE_TEMPLATE.md) +- [Submitting a pull request guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/PULL_REQUEST_TEMPLATE.md) +- [Maintainers](https://github.com/rocm/composable_kernel/blob/develop/CONTRIBUTORS.md) +- [General information](https://github.com/rocm/composable_kernel/blob/develop/README.md) +- [ROCm documentation](https://rocm.docs.amd.com/en/latest/how-to/llm-fine-tuning-optimization/optimizing-with-composable-kernel.html) \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000000..263cc3480d --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,14 @@ +When creating an issue, please check if a similar issue already exists. + +### When reporting a bug, please include: +- [ ] A descriptive title +- [ ] An isolated way to reproduce the behavior (preferably a docker container with a repro) +- [ ] ROCm version, clang version, Composable Kernel commit pin +- [ ] Environment variables +- [ ] The behavior you expect to see, and the behavior you actually see + +### When requesting a feature, please include: +- [ ] A descriptive title +- [ ] A detailed description of the problem you are trying to solve +- [ ] An overview of the suggested solution +- [ ] Explanation why the solution is an improvement \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000..8a988ad1c9 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,20 @@ +## Proposed changes + +Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request. + +## Checklist + +Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. + +- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally +- [ ] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. +- [ ] I have added inline documentation which enables the maintainers with understanding the motivation +- [ ] I have removed the stale documentation which is no longer relevant after this pull request +- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request +- [ ] I have run `clang-format` on all changed files +- [ ] Any dependent changes have been merged + +## Discussion + +If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered + diff --git a/.gitignore b/.gitignore index f4d5ff7abd..e4dd8f7513 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,8 @@ _static/ _templates/ _toc.yml _doxygen/ +docs/doxygen/html +docs/doxygen/xml # JetBrains IDE .idea/ @@ -66,3 +68,6 @@ build*/ # Python cache __pycache__/ + +.cache/ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100755 new mode 100644 index d6700ae05b..e4e85651f6 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,3 +12,27 @@ repos: verbose: false language: script types: [c++] + - id: remove-exec-bit + name: Remove executable bit from non-executable files + entry: script/remove_exec_bit.sh + language: script + types_or: [c++, text] + verbose: true + - id: ruff-check + name: Ruff Linter + entry: ruff check --fix + language: python + types: [python] + additional_dependencies: [ruff] + - id: ruff-format + name: Ruff Formatter + entry: ruff format + language: python + types: [python] + additional_dependencies: [ruff] + - id: run-remod-if-ck-tile-changed + name: Run remod.py if ck_tile files changed + entry: script/remod_for_ck_tile.sh + language: script + always_run: true + pass_filenames: false diff --git a/CHANGELOG.md b/CHANGELOG.md index dec6334cf5..17f9455feb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,53 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). +## Composable Kernel 1.1.0 for ROCm 6.5.0 + +### Added + +* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. +* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). +* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). +* Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for Multiple D GEMM +* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types +* Added support for FP16 2:4 structured sparsity to universal GEMM. +* Added support for Split K for grouped convolution backward data. +* Added logit soft-capping support for fMHA forward kernels. +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) +* Added benchmarking support for tile engine GEMM. +* Added Ping-pong scheduler support for GEMM operation along the K dimension. +* Added rotating buffer feature for CK_Tile GEMM. +* Added int8 support for CK_TILE GEMM. + +### Optimized + + +* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) +* Added Vectorize Transpose optimization for CK Tile (#2131) +* Added the asynchronous copy for gfx950 (#2425) + + +### Fixes + +None + +### Changes + +* Removed support for gfx940 and gfx941 targets (#1944) +* Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876) +* DL and DPP kernels are now enabled by default. +* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. +* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. +* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. + +### Known issues + +None + ## Composable Kernel 1.1.0 for ROCm 6.1.0 ### Additions diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d5f053fc5..eaf22d3f35 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,13 +26,21 @@ set(version 1.1.0) project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) +option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) +option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) + +# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" +# CK Codegen requires dataclass which is added in Python 3.7 +# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 if(NOT CK_USE_ALTERNATIVE_PYTHON) - find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) + find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED) else() - message("Using alternative python version") + message(STATUS "Using alternative python version") set(EXTRA_PYTHON_PATH) + # this is overly restrictive, we may need to be more flexible on the following string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}") - message("alternative python path is: ${EXTRA_PYTHON_PATH}") + message(STATUS "alternative python path is: ${EXTRA_PYTHON_PATH}") find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}") set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}") @@ -72,7 +80,7 @@ if (DTYPES) add_definitions(-DCK_ENABLE_BF16) set(CK_ENABLE_BF16 "ON") endif() - message("DTYPES macro set to ${DTYPES}") + message(STATUS "DTYPES macro set to ${DTYPES}") else() add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8) set(CK_ENABLE_INT8 "ON") @@ -88,15 +96,31 @@ endif() add_compile_options(-Wno-bit-int-extension) add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) +add_compile_options(-Wno-unique-object-duplication) -if(DL_KERNELS) +# Recent change in compiler makes this warning ON by default, which led to compile errors. +add_compile_options(-Wno-nrvo) + +if(NOT DISABLE_DL_KERNELS) add_definitions(-DDL_KERNELS) + set(DL_KERNELS "ON") set(CK_ENABLE_DL_KERNELS "ON") endif() +if(NOT DISABLE_DPP_KERNELS) + add_definitions(-DDPP_KERNELS) + set(DPP_KERNELS "ON") + set(CK_ENABLE_DPP_KERNELS "ON") +endif() +option(CK_USE_CODEGEN "Enable codegen library" OFF) +if(CK_USE_CODEGEN) + add_definitions(-DCK_USE_CODEGEN) +endif() -if(INSTANCES_ONLY) - add_definitions(-DINSTANCES_ONLY) - set(CK_ENABLE_INSTANCES_ONLY "ON") +option(CK_TIME_KERNEL "Enable kernel time tracking" ON) +if(CK_TIME_KERNEL) + add_definitions(-DCK_TIME_KERNEL=1) +else() + add_definitions(-DCK_TIME_KERNEL=0) endif() # enable prevent in source build @@ -128,79 +152,103 @@ rocm_setup_version(VERSION ${version}) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}") -message("GPU_TARGETS= ${GPU_TARGETS}") - -find_package(hip) +message(STATUS "GPU_TARGETS= ${GPU_TARGETS}") +message(STATUS "GPU_ARCHS= ${GPU_ARCHS}") +if(GPU_ARCHS) + #disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package + unset(GPU_TARGETS CACHE) + unset(AMDGPU_TARGETS CACHE) +endif() +if(GPU_TARGETS) + set(USER_GPU_TARGETS 1) +else() + set(USER_GPU_TARGETS 0) +endif() +find_package(hip REQUIRED) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") -message("hip_version_flat=${hip_VERSION_FLAT}") +message(STATUS "hip_version_flat=${hip_VERSION_FLAT}") -message("checking which targets are supported") -#This is the list of targets to be used in case GPU_TARGETS is not set on command line -#These targets will be filtered and only supported ones will be used -#Setting GPU_TARGETS on command line will override this list -if(NOT PROFILER_ONLY) - if(NOT ENABLE_ASAN_PACKAGING) - #build CK for all supported targets - if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) - # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") - else() - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") - endif() - else() - #build CK only for xnack-supported targets - rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") - set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) +message(STATUS "checking which targets are supported") +#In order to build just the CK library (without tests and examples) for all supported GPU targets +#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" +#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts. +# +#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures. +if(NOT ENABLE_ASAN_PACKAGING) + if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) + # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600300000 AND ${hip_VERSION_FLAT} LESS 600400000) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000 AND ${hip_VERSION_FLAT} LESS 600443483) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950") + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx950;gfx10-3-generic;gfx11-generic;gfx12-generic") endif() else() - add_definitions(-DPROFILER_ONLY) - set(GPU_TARGETS "" CACHE STRING "" FORCE) - if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") - endif() - if(GPU_ARCH MATCHES "gfx90") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") - elseif(GPU_ARCH MATCHES "gfx94") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942") - elseif(GPU_ARCH MATCHES "gfx10") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") - elseif(GPU_ARCH MATCHES "gfx11") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") - elseif(GPU_ARCH MATCHES "gfx12") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") - else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") - endif() - set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + #build CK only for xnack-supported targets when using ASAN + set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+") endif() -message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") - -if(GPU_TARGETS) - message("Building CK for the following targets: ${GPU_TARGETS}") +#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list +#otherwise, if user set GPU_TARGETS, use that set of targets +if(GPU_ARCHS) + set(CK_GPU_TARGETS ${GPU_ARCHS}) else() - message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}") + if(USER_GPU_TARGETS) + set(CK_GPU_TARGETS ${GPU_TARGETS}) + endif() endif() +#if the user did not set GPU_TARGETS, delete whatever was set by HIP package +if(NOT USER_GPU_TARGETS) + set(GPU_TARGETS "") +endif() +#make sure all the targets on the list are actually supported by the current compiler +rocm_check_target_ids(SUPPORTED_GPU_TARGETS + TARGETS ${CK_GPU_TARGETS}) -if (GPU_TARGETS) - if (GPU_TARGETS MATCHES "gfx9") - add_definitions(-DCK_USE_XDL) - set(CK_USE_XDL "ON") - endif() - if (GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") - add_definitions(-DCK_USE_WMMA) - set(CK_USE_WMMA "ON") - endif() -else() - add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) +message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") + +if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") + message(STATUS "Enabling XDL instances") + add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") + message(STATUS "Enabling XDL FP8 gemms on native architectures") + add_definitions(-DCK_USE_GFX94) + set(CK_USE_GFX94 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message(STATUS "Enabling WMMA instances") + add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message(STATUS "Enabling WMMA FP8 gemms on native architectures") + add_definitions(-DCK_USE_WMMA_FP8) + set(CK_USE_WMMA_FP8 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_OCP_FP8) + set(CK_USE_OCP_FP8 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx94") + add_definitions(-DCK_USE_FNUZ_FP8) + set(CK_USE_FNUZ_FP8 "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) + set(CK_USE_NATIVE_MX_SUPPORT "ON") +endif() + +option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) +if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) + add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) + set(CK_USE_FP8_ON_UNSUPPORTED_ARCH "ON") +endif() # CK config file to record supported datatypes, etc. configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h) @@ -208,25 +256,32 @@ configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/con if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302) check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK) if(HAS_NO_OFFLOAD_UNIFORM_BLOCK) - message("Adding the fno-offload-uniform-block compiler flag") + message(STATUS "Adding the fno-offload-uniform-block compiler flag") add_compile_options(-fno-offload-uniform-block) endif() endif() +if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000) + check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION) + if(HAS_LSR_DROP_SOLUTION) + message(STATUS "Adding the lsr-drop-solution=1 compiler flag") + add_compile_options("SHELL: -mllvm --lsr-drop-solution=1") + endif() +endif() if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) if(HAS_ENABLE_POST_MISCHED) - message("Adding the enable-post-misched=0 compiler flag") + message(STATUS "Adding the enable-post-misched=0 compiler flag") add_compile_options("SHELL: -mllvm -enable-post-misched=0") endif() endif() set(check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) -if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132 AND ${hip_VERSION_FLAT} LESS 600300000) - message("Adding the amdgpu-coerce-illegal-types=1") +if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) + message(STATUS "Adding the amdgpu-coerce-illegal-types=1") add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1") endif() if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) - message("Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false") + message(STATUS "Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false") add_compile_options("SHELL: -mllvm -amdgpu-early-inline-all=true") add_compile_options("SHELL: -mllvm -amdgpu-function-calls=false") endif() @@ -259,17 +314,24 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) +option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) add_compile_options(-Wno-bit-int-extension) - message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") + message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") + message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") +endif() + +if(ENABLE_ASM_DUMP) + add_compile_options(--save-temps) + add_compile_options(-Wno-gnu-line-marker) + message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() ## Threads @@ -281,7 +343,7 @@ link_libraries(Threads::Threads) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") +message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") # https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_macros.html # _GLIBCXX_ASSERTIONS @@ -297,7 +359,7 @@ endif() set(CMAKE_HIP_PLATFORM amd) set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) set(CMAKE_HIP_EXTENSIONS ON) -message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") +message(STATUS "CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") ## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") @@ -312,16 +374,15 @@ else() find_package(OpenMP REQUIRED) endif() -message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") -message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") -message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") -message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") +message(STATUS "OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") +message(STATUS "OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") +message(STATUS "OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") +message(STATUS "OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") link_libraries(${OpenMP_gomp_LIBRARY}) link_libraries(${OpenMP_pthread_LIBRARY}) ## HIP -find_package(HIP REQUIRED) # Override HIP version in config.h, if necessary. # The variables set by find_package() can't be overwritten, # therefore let's use intermediate variables. @@ -348,146 +409,152 @@ else() add_compile_definitions(__HIP_PLATFORM_HCC__=1) endif() -## tidy include(EnableCompilerWarnings) +## tidy set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") - set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) +set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) # Enable tidy on hip elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") - set(CK_TIDY_ERRORS ALL) +set(CK_TIDY_ERRORS ALL) endif() -include(ClangTidy) -enable_clang_tidy( - CHECKS - * - -abseil-* - -android-cloexec-fopen - # Yea we shouldn't be using rand() - -cert-msc30-c - -bugprone-exception-escape - -bugprone-macro-parentheses - -cert-env33-c - -cert-msc32-c - -cert-msc50-cpp - -cert-msc51-cpp - -cert-dcl37-c - -cert-dcl51-cpp - -clang-analyzer-alpha.core.CastToStruct - -clang-analyzer-optin.performance.Padding - -clang-diagnostic-deprecated-declarations - -clang-diagnostic-extern-c-compat - -clang-diagnostic-unused-command-line-argument - -cppcoreguidelines-avoid-c-arrays - -cppcoreguidelines-avoid-magic-numbers - -cppcoreguidelines-explicit-virtual-functions - -cppcoreguidelines-init-variables - -cppcoreguidelines-macro-usage - -cppcoreguidelines-non-private-member-variables-in-classes - -cppcoreguidelines-pro-bounds-array-to-pointer-decay - -cppcoreguidelines-pro-bounds-constant-array-index - -cppcoreguidelines-pro-bounds-pointer-arithmetic - -cppcoreguidelines-pro-type-member-init - -cppcoreguidelines-pro-type-reinterpret-cast - -cppcoreguidelines-pro-type-union-access - -cppcoreguidelines-pro-type-vararg - -cppcoreguidelines-special-member-functions - -fuchsia-* - -google-explicit-constructor - -google-readability-braces-around-statements - -google-readability-todo - -google-runtime-int - -google-runtime-references - -hicpp-vararg - -hicpp-braces-around-statements - -hicpp-explicit-conversions - -hicpp-named-parameter - -hicpp-no-array-decay - # We really shouldn't use bitwise operators with signed integers, but - # opencl leaves us no choice - -hicpp-avoid-c-arrays - -hicpp-signed-bitwise - -hicpp-special-member-functions - -hicpp-uppercase-literal-suffix - -hicpp-use-auto - -hicpp-use-equals-default - -hicpp-use-override - -llvm-header-guard - -llvm-include-order - #-llvmlibc-* - -llvmlibc-restrict-system-libc-headers - -llvmlibc-callee-namespace - -llvmlibc-implementation-in-namespace - -llvm-else-after-return - -llvm-qualified-auto - -misc-misplaced-const - -misc-non-private-member-variables-in-classes - -misc-no-recursion - -modernize-avoid-bind - -modernize-avoid-c-arrays - -modernize-pass-by-value - -modernize-use-auto - -modernize-use-default-member-init - -modernize-use-equals-default - -modernize-use-trailing-return-type - -modernize-use-transparent-functors - -performance-unnecessary-value-param - -readability-braces-around-statements - -readability-else-after-return - # we are not ready to use it, but very useful - -readability-function-cognitive-complexity - -readability-isolate-declaration - -readability-magic-numbers - -readability-named-parameter - -readability-uppercase-literal-suffix - -readability-convert-member-functions-to-static - -readability-qualified-auto - -readability-redundant-string-init - # too many narrowing conversions in our code - -bugprone-narrowing-conversions - -cppcoreguidelines-narrowing-conversions - -altera-struct-pack-align - -cppcoreguidelines-prefer-member-initializer - ${CK_TIDY_CHECKS} - ${CK_TIDY_ERRORS} - HEADER_FILTER - "\.hpp$" - EXTRA_ARGS - -DCK_USE_CLANG_TIDY -) +if(ENABLE_CLANG_CPP_CHECKS) + include(ClangTidy) + enable_clang_tidy( + CHECKS + * + -abseil-* + -android-cloexec-fopen + # Yea we shouldn't be using rand() + -cert-msc30-c + -bugprone-exception-escape + -bugprone-macro-parentheses + -cert-env33-c + -cert-msc32-c + -cert-msc50-cpp + -cert-msc51-cpp + -cert-dcl37-c + -cert-dcl51-cpp + -clang-analyzer-alpha.core.CastToStruct + -clang-analyzer-optin.performance.Padding + -clang-diagnostic-deprecated-declarations + -clang-diagnostic-extern-c-compat + -clang-diagnostic-unused-command-line-argument + -cppcoreguidelines-avoid-c-arrays + -cppcoreguidelines-avoid-magic-numbers + -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables + -cppcoreguidelines-macro-usage + -cppcoreguidelines-non-private-member-variables-in-classes + -cppcoreguidelines-pro-bounds-array-to-pointer-decay + -cppcoreguidelines-pro-bounds-constant-array-index + -cppcoreguidelines-pro-bounds-pointer-arithmetic + -cppcoreguidelines-pro-type-member-init + -cppcoreguidelines-pro-type-reinterpret-cast + -cppcoreguidelines-pro-type-union-access + -cppcoreguidelines-pro-type-vararg + -cppcoreguidelines-special-member-functions + -fuchsia-* + -google-explicit-constructor + -google-readability-braces-around-statements + -google-readability-todo + -google-runtime-int + -google-runtime-references + -hicpp-vararg + -hicpp-braces-around-statements + -hicpp-explicit-conversions + -hicpp-named-parameter + -hicpp-no-array-decay + # We really shouldn't use bitwise operators with signed integers, but + # opencl leaves us no choice + -hicpp-avoid-c-arrays + -hicpp-signed-bitwise + -hicpp-special-member-functions + -hicpp-uppercase-literal-suffix + -hicpp-use-auto + -hicpp-use-equals-default + -hicpp-use-override + -llvm-header-guard + -llvm-include-order + #-llvmlibc-* + -llvmlibc-restrict-system-libc-headers + -llvmlibc-callee-namespace + -llvmlibc-implementation-in-namespace + -llvm-else-after-return + -llvm-qualified-auto + -misc-misplaced-const + -misc-non-private-member-variables-in-classes + -misc-no-recursion + -modernize-avoid-bind + -modernize-avoid-c-arrays + -modernize-pass-by-value + -modernize-use-auto + -modernize-use-default-member-init + -modernize-use-equals-default + -modernize-use-trailing-return-type + -modernize-use-transparent-functors + -performance-unnecessary-value-param + -readability-braces-around-statements + -readability-else-after-return + # we are not ready to use it, but very useful + -readability-function-cognitive-complexity + -readability-isolate-declaration + -readability-magic-numbers + -readability-named-parameter + -readability-uppercase-literal-suffix + -readability-convert-member-functions-to-static + -readability-qualified-auto + -readability-redundant-string-init + # too many narrowing conversions in our code + -bugprone-narrowing-conversions + -cppcoreguidelines-narrowing-conversions + -altera-struct-pack-align + -cppcoreguidelines-prefer-member-initializer + ${CK_TIDY_CHECKS} + ${CK_TIDY_ERRORS} + HEADER_FILTER + "\.hpp$" + EXTRA_ARGS + -DCK_USE_CLANG_TIDY + ) -include(CppCheck) -enable_cppcheck( - CHECKS - warning - style - performance - portability - SUPPRESS - ConfigurationNotChecked - constStatement - duplicateCondition - noExplicitConstructor - passedByValue - preprocessorErrorDirective - shadowVariable - unusedFunction - unusedPrivateFunction - unusedStructMember - unmatchedSuppression - FORCE - SOURCES - library/src - INCLUDE - ${CMAKE_CURRENT_SOURCE_DIR}/include - ${CMAKE_CURRENT_BINARY_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR}/library/include - DEFINE - CPPCHECK=1 - __linux__=1 -) + include(CppCheck) + enable_cppcheck( + CHECKS + warning + style + performance + portability + SUPPRESS + ConfigurationNotChecked + constStatement + duplicateCondition + noExplicitConstructor + passedByValue + preprocessorErrorDirective + shadowVariable + unusedFunction + unusedPrivateFunction + unusedStructMember + unmatchedSuppression + FORCE + SOURCES + library/src + INCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/library/include + DEFINE + CPPCHECK=1 + __linux__=1 + ) +else() + function(clang_tidy_check TARGET) + # stub out empty function if clang tidy is not enabled + endfunction() +endif() set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) @@ -506,7 +573,7 @@ if(BUILD_DEV) add_compile_options(-Werror) add_compile_options(-Weverything) endif() -message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") add_compile_options(-fcolor-diagnostics) @@ -515,7 +582,16 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() -add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +if(NOT MIOPEN_REQ_LIBS_ONLY) + # make check runs the entire set of examples and tests + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + # make smoke runs the tests and examples that runs within 30 seconds on gfx90a + add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") + # make regression runs the tests and examples that runs for more 30 seconds on gfx90a + add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") +endif() + + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) @@ -557,10 +633,14 @@ ENDIF() ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) + +option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF) +option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) + add_subdirectory(library) -if(NOT DEFINED INSTANCES_ONLY) - if(NOT DEFINED PROFILER_ONLY) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name @@ -571,24 +651,21 @@ if(NOT DEFINED INSTANCES_ONLY) PACKAGE_NAME examples ) add_subdirectory(example) - add_subdirectory(test) - - rocm_package_setup_component(profiler - LIBRARY_NAME composablekernel - PACKAGE_NAME ckprofiler - ) - add_subdirectory(profiler) - else() - #When building PROFILER_ONLY, label the package with GPU_ARCH - rocm_package_setup_component(profiler - LIBRARY_NAME composablekernel - PACKAGE_NAME ckprofiler_${GPU_ARCH} - ) - add_subdirectory(profiler) - endif() + add_subdirectory(tile_engine) + if(BUILD_TESTING) + add_subdirectory(test) + endif() endif() -if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY)) +if (NOT MIOPEN_REQ_LIBS_ONLY) + rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckprofiler + ) + add_subdirectory(profiler) +endif() + +if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) endif() diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index cdce5a4630..0900b7a1f8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -1,3 +1,4 @@ +[Back to the main page](./README.md) # Composable Kernel Developers and Contributors This is the list of developers and contributors to Composable Kernel library @@ -19,10 +20,11 @@ Tejash Shah, 2019-2020 Xiaoyan Zhou, 2020 [Jianfeng Yan](https://github.com/j4yan), 2021-2022 - +[Jun Liu](https://github.com/junliume), 2021-2024 ## Product Manager -[Jun Liu](https://github.com/junliume) +[John Afaganis](https://github.com/afagaj) + ## Contributors diff --git a/Dockerfile b/Dockerfile index 04d0357f51..0219f99238 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,32 +1,27 @@ -FROM ubuntu:20.04 +FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.2 +ARG ROCMVERSION=6.4.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" - -RUN set -xe - ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ -RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins -# Add rocm repository -RUN chmod 1777 /tmp -RUN apt-get update -RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl - ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn -RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.3" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.2.60200-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.2.60200-1_all.deb && \ +# Add rocm repository +RUN set -xe && \ + apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ + curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + +RUN if [ "$ROCMVERSION" != "6.5" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ fi -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" -RUN amdgpu-install -y --usecase=rocm --no-dkms +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ + amdgpu-install -y --usecase=rocm --no-dkms ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache @@ -48,16 +43,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- iputils-ping \ jq \ libelf-dev \ - libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ + mpich \ net-tools \ pkg-config \ - python \ - python3 \ - python3-dev \ - python3-pip \ + python3-full \ redis \ rocm-llvm-dev \ sshpass \ @@ -67,70 +59,53 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- nano \ zlib1g-dev \ zip \ + libzstd-dev \ openssh-server \ clang-format-12 \ kmod && \ apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 -RUN if [ "$ROCMVERSION" = "6.1" ]; then \ - sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \ - fi -# Update the cmake to version 3.27.5 -RUN pip install --upgrade cmake==3.27.5 + rm -rf /var/lib/apt/lists/* && \ + rm -rf amdgpu-install* && \ +# Remove unnecessary rocm components that take a lot of space + apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt #Install latest ccache RUN git clone https://github.com/ccache/ccache.git && \ - cd ccache && mkdir build && cd build && cmake .. && make install - + cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools -RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip -RUN gunzip /usr/local/bin/ninja.gz -RUN chmod a+x /usr/local/bin/ninja -RUN git clone https://github.com/nico/ninjatracing.git - + cd / && \ + wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip && \ + gunzip /usr/local/bin/ninja.gz && \ + chmod a+x /usr/local/bin/ninja && \ +#Install ClangBuildAnalyzer + git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ + cd ClangBuildAnalyzer/ && \ + make -f projects/make/Makefile && \ + cd / && \ #Install latest cppcheck -RUN git clone https://github.com/danmar/cppcheck.git && \ - cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . -WORKDIR / - -# Setup ubsan environment to printstacktrace -RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer -ENV UBSAN_OPTIONS=print_stacktrace=1 - + git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \ + cd / && \ # Install an init system -RUN wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb -RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb - -ARG PREFIX=/opt/rocm + wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \ + dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results -RUN pip3 install --upgrade pip -RUN pip3 install sqlalchemy==1.4.46 -RUN pip3 install pymysql -RUN pip3 install pandas==2.0.3 -RUN pip3 install setuptools-rust -RUN pip3 install sshtunnel==0.4.0 -# Setup ubsan environment to printstacktrace -ENV UBSAN_OPTIONS=print_stacktrace=1 - -ENV LC_ALL=C.UTF-8 -ENV LANG=C.UTF-8 -RUN groupadd -f render - + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ +# Add render group + groupadd -f render && \ # Install the new rocm-cmake version -RUN git clone -b master https://github.com/ROCm/rocm-cmake.git && \ - cd rocm-cmake && mkdir build && cd build && \ - cmake .. && cmake --build . && cmake --build . --target install + git clone -b master https://github.com/ROCm/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install WORKDIR / - +# Add alternative compilers, if necessary ENV compiler_version=$compiler_version ENV compiler_commit=$compiler_commit -RUN sh -c "echo compiler version = '$compiler_version'" -RUN sh -c "echo compiler commit = '$compiler_commit'" +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" -RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ @@ -138,16 +113,10 @@ RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd else echo "using the release compiler"; \ fi -RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi - -#clean-up the deb package -RUN sh -c "rm -rf amdgpu-install*" - -#ENV HIP_CLANG_PATH='/llvm-project/build/bin' -#RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" diff --git a/Dockerfile.compiler b/Dockerfile.compiler new file mode 100644 index 0000000000..0306057e45 --- /dev/null +++ b/Dockerfile.compiler @@ -0,0 +1,26 @@ +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1" +FROM $BASE_DOCKER +ARG compiler_version="" +ARG compiler_commit="" + +# Add alternative compilers, if necessary +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 16 ; \ + else echo "using the release compiler"; \ + fi + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 16 ; \ + else echo "using the release compiler"; \ + fi diff --git a/Jenkinsfile b/Jenkinsfile index 60f4d60d23..50c15701a7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -12,6 +12,23 @@ def show_node_info() { """ } +class Version { + int major, minor, patch + @Override + String toString() { + return [major, minor, patch].findAll().join('.') + } +} +def parseVersion(String versionString) { + if (!versionString) return null + int[] tokens = versionString.split(/\./).collect { it as int } // Splits the string by '.' and converts each part to an integer. + return new Version( + major: tokens[0], + minor: tokens.length > 1 ? tokens[1] : null, + patch: tokens.length > 2 ? tokens[2] : null, + ) +} + def nthreads() { def nproc = sh(returnStdout: true, script: 'nproc') echo "Number of cores: ${nproc}" @@ -32,41 +49,43 @@ def runShell(String command){ return (output != "") } -def getDockerImageName(){ +def getBaseDockerImageName(){ def img if (params.USE_CUSTOM_DOCKER != ""){ img = "${params.USE_CUSTOM_DOCKER}" } else{ - if (params.ROCMVERSION != "6.3"){ - if (params.COMPILER_VERSION == "") { - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" - } - else{ - if (params.COMPILER_COMMIT == ""){ - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" - } - else{ - def commit = "${params.COMPILER_COMMIT}"[0..6] - img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" - } - } + def ROCM_numeric = parseVersion("${params.ROCMVERSION}") + if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){ + img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" + } + else{ + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm${params.ROCMVERSION}" + } + } + return img +} + +def getDockerImageName(){ + def img + def base_name = getBaseDockerImageName() + if (params.USE_CUSTOM_DOCKER != ""){ + img = "${params.USE_CUSTOM_DOCKER}" } else{ if (params.COMPILER_VERSION == "") { - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}" + img = "${base_name}" } else{ if (params.COMPILER_COMMIT == ""){ - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + img = "${base_name}_${params.COMPILER_VERSION}" } else{ def commit = "${params.COMPILER_COMMIT}"[0..6] - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" + img = "${base_name}_${params.COMPILER_VERSION}_${commit}" } } } - } return img } @@ -74,6 +93,7 @@ def check_host() { if ("${env.CK_SCCACHE}" != "null"){ def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" echo "sccache server: ${SCCACHE_SERVER}" + sh "chmod +w -R ${env.WORKSPACE}" sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? > tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" @@ -90,24 +110,63 @@ def build_compiler(){ return compiler } +def check_arch(){ + def arch_type = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { + arch_type = 3 + } + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { + arch_type = 4 + } + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { + arch_type = 5 + } + else if ( runShell('grep -n "gfx908" rocminfo.log') ) { + arch_type = 6 + } + else if ( runShell('grep -n "gfx950" rocminfo.log') ) { + arch_type = 7 + } + return arch_type +} + def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") def no_cache = conf.get("no_cache", false) - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' " if(no_cache) { dockerArgs = dockerArgs + " --no-cache " } echo "Docker Args: ${dockerArgs}" - def image = getDockerImageName() + def image + if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using legacy docker: ${image}" + } + else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using special docker: ${image}" + } + else{ + image = getDockerImageName() + echo "Using default docker: ${image}" + } //Check if image exists def retimage try { - echo "Pulling down image: ${image}" + echo "Pulling image: ${image}" retimage = docker.image("${image}") - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() } } @@ -123,16 +182,22 @@ def buildDocker(install_prefix){ env.DOCKER_BUILDKIT=1 checkout scm def image_name = getDockerImageName() + def base_image_name = getBaseDockerImageName() echo "Building Docker for ${image_name}" - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - + def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ + dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . " + } + else{ + dockerArgs = dockerArgs + " -f Dockerfile . " + } echo "Build Args: ${dockerArgs}" try{ if(params.BUILD_DOCKER){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" - retimage = docker.build("${image_name}", dockerArgs + ' .') - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage = docker.build("${image_name}", dockerArgs) + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' @@ -146,7 +211,7 @@ def buildDocker(install_prefix){ catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.push() } } @@ -160,13 +225,20 @@ def cmake_build(Map conf=[:]){ def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","") def prefixpath = conf.get("prefixpath","/opt/rocm") def setup_args = conf.get("setup_args","") - + // make sure all unit tests always run on develop branch + def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS + if (prefixpath != "/usr/local"){ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " } def build_type_debug = (conf.get("build_type",'release') == 'debug') + // use special compiler for gfx950 + if ( check_arch() == 7){ + compiler = "/llvm-project/build/bin/clang++" + } + //cmake_env can overwrite default CXX variables. def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") @@ -183,8 +255,8 @@ def cmake_build(Map conf=[:]){ } else{ setup_args = ' -DBUILD_DEV=On' + setup_args } - if (params.DL_KERNELS){ - setup_args = setup_args + " -DDL_KERNELS=ON " + if (params.DISABLE_DL_KERNELS){ + setup_args = setup_args + " -DDISABLE_DL_KERNELS=ON " } if(build_type_debug){ @@ -213,12 +285,18 @@ def cmake_build(Map conf=[:]){ if (setup_args.contains("gfx10")){ invocation_tag="gfx10" } - if (setup_args.contains("gfx90")){ - invocation_tag="gfx90" + if (setup_args.contains("gfx908")){ + invocation_tag="gfx908" + } + if (setup_args.contains("gfx90a")){ + invocation_tag="gfx90a" } if (setup_args.contains("gfx94")){ invocation_tag="gfx94" } + if (setup_args.contains("gfx95")){ + invocation_tag="gfx95" + } echo "invocation tag: ${invocation_tag}" def redis_pre_setup_cmd = pre_setup_cmd if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") { @@ -259,6 +337,7 @@ def cmake_build(Map conf=[:]){ """) sh cmd3 } + // reduce parallelism when compiling, clang uses too much memory def nt = nthreads() def cmd @@ -266,15 +345,19 @@ def cmake_build(Map conf=[:]){ def build_cmd def execute_cmd = conf.get("execute_cmd", "") if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + def cmake_flags = params.NINJA_FTIME_TRACE ? "-O3 -ftime-trace" : "-O3" + if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake -G Ninja ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") - } - else{ - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") } + setup_cmd = conf.get( + "setup_cmd", + """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" ${cmake_flags} " .. """ + ) + build_cmd = conf.get( + "build_cmd", + "${build_envs} ninja -j${nt} ${config_targets}" + ) + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -292,31 +375,93 @@ def cmake_build(Map conf=[:]){ dir("build"){ //build CK sh cmd - //run tests - if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ - sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" + //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set + if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ + if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){ + if (params.NINJA_FTIME_TRACE) { + echo "running ninja ftime trace" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" + archiveArtifacts "clang_build_analysis.log" + } + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace.json" archiveArtifacts "ck_build_trace.json" - sh "ninja test" + + // do not run unit tests when building instances only + if(!params.BUILD_INSTANCES_ONLY){ + if (!runAllUnitTests){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } + } + if(params.BUILD_INSTANCES_ONLY){ + // build deb packages + echo "Build packages" + sh 'ninja -j64 package' + archiveArtifacts artifacts: 'composablekernel-dev*.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' + stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages" + } } else{ - sh "make check" + // run unit tests unless building library for all targets + if (!params.BUILD_INSTANCES_ONLY){ + if (!runAllUnitTests){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } + } } } } - // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { + // Only archive from develop + if (package_build == true && env.BRANCH_NAME == "develop") { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } - if (params.RUN_CK_TILE_TESTS){ + //check the node gpu architecture + def arch = check_arch() + if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - archiveArtifacts "perf_fmha_fwd_*.log" - archiveArtifacts "perf_fmha_bwd_*.log" - stash name: "perf_fmha_fwd_gfx942.log" - stash name: "perf_fmha_bwd_gfx942.log" - stash name: "perf_fmha_fwd_gfx90a.log" - stash name: "perf_fmha_bwd_gfx90a.log" + archiveArtifacts "perf_fmha_*.log" + if (arch == 1){ + stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" + } + else if (arch == 2){ + stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." + } + } + if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ + try{ + archiveArtifacts "perf_transpose_*.log" + if (arch_type == 1){ + stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." + } + } + if (params.RUN_CK_TILE_GEMM_TESTS){ + try{ + archiveArtifacts "perf_tile_gemm_**.log" + if (arch == 1){ + stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" + } + else if (arch == 2){ + stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942" + } } catch(Exception err){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." @@ -329,17 +474,21 @@ def buildHipClangJob(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - - def image = getDockerImageName() def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts + if ( params.BUILD_INSTANCES_ONLY ){ + dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } + else{ + dockerOpts = "--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') @@ -348,13 +497,13 @@ def buildHipClangJob(Map conf=[:]){ echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME - + def image def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 48, unit: 'HOURS') + timeout(time: 20, unit: 'HOURS') { cmake_build(conf) } @@ -383,35 +532,40 @@ def buildHipClangJobAndReboot(Map conf=[:]){ } } -def runCKProfiler(Map conf=[:]){ +def Build_CK(Map conf=[:]){ show_node_info() env.HSA_ENABLE_SDMA=0 + env.DOCKER_BUILDKIT=1 checkout scm - - def image = getDockerImageName() def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + } + if(params.BUILD_LEGACY_OS){ + dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' " + } def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - def variant = env.STAGE_NAME + def image def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ + timeout(time: 2, unit: 'MINUTES'){ sh 'rocminfo | tee rocminfo.log' if ( !runShell('grep -n "gfx" rocminfo.log') ){ throw new Exception ("GPU not found") @@ -426,21 +580,41 @@ def runCKProfiler(Map conf=[:]){ echo "The job was cancelled or aborted" throw e } - withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 24, unit: 'HOURS') + timeout(time: 20, unit: 'HOURS') { - sh """ - rm -rf build - mkdir build - """ - dir("build"){ - unstash 'ckProfiler.tar.gz' - sh 'tar -xvf ckProfiler.tar.gz' + //check whether to run performance tests on this node + def arch = check_arch() + cmake_build(conf) + if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch == 1 ){ + echo "Run inductor codegen tests" + sh """ + python3 -m venv ${env.WORKSPACE} + . ${env.WORKSPACE}/bin/activate + python3 -m pip install pytest build setuptools setuptools_scm + python3 -m pip install . + python3 -m pytest python/test/test_gen_instances.py + """ } - + dir("build"){ + if (params.RUN_FULL_QA && arch == 2 ){ + // build deb packages + echo "Build packages" + sh 'make -j package' + archiveArtifacts artifacts: 'composablekernel*.deb' + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' + sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb' + sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb' + stash includes: "composablekernel-**.deb", name: "packages" + } + } + // run performance tests, stash the logs, results will be processed on the master node dir("script"){ - if (params.RUN_FULL_QA){ + if (params.RUN_PERFORMANCE_TESTS){ + if (params.RUN_FULL_QA && arch == 1){ + // run full tests on gfx90a + echo "Run full performance tests" sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" archiveArtifacts "perf_gemm.log" archiveArtifacts "perf_resnet50_N256.log" @@ -455,133 +629,58 @@ def runCKProfiler(Map conf=[:]){ archiveArtifacts "perf_splitK_gemm.log" archiveArtifacts "perf_onnx_gemm.log" archiveArtifacts "perf_mixed_gemm.log" - // stash perf files to master - stash name: "perf_gemm.log" - stash name: "perf_resnet50_N256.log" - stash name: "perf_resnet50_N4.log" - stash name: "perf_batched_gemm.log" - stash name: "perf_grouped_gemm.log" - stash name: "perf_grouped_conv_fwd.log" - stash name: "perf_grouped_conv_bwd_data.log" - stash name: "perf_grouped_conv_bwd_weight.log" - stash name: "perf_gemm_bilinear.log" - stash name: "perf_reduction.log" - stash name: "perf_splitK_gemm.log" - stash name: "perf_onnx_gemm.log" - stash name: "perf_mixed_gemm.log" - //we will process results on the master node + stash includes: "perf_**.log", name: "perf_log" } - else{ + else if ( arch == 1 ){ + // run standard tests on gfx90a + echo "Run performance tests" sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" archiveArtifacts "perf_gemm.log" + archiveArtifacts "perf_onnx_gemm.log" archiveArtifacts "perf_resnet50_N256.log" archiveArtifacts "perf_resnet50_N4.log" - // stash perf files to master - stash name: "perf_gemm.log" - stash name: "perf_resnet50_N256.log" - stash name: "perf_resnet50_N4.log" - //we will process the results on the master node + stash includes: "perf_**.log", name: "perf_log" } - } - } - } - } - return retimage -} - -def runPerfTest(Map conf=[:]){ - try{ - runCKProfiler(conf) - } - catch(e){ - echo "throwing error exception in performance tests" - echo 'Exception occurred: ' + e.toString() - throw e - } - finally{ - if (!conf.get("no_reboot", false)) { - reboot() - } - } -} - -def Build_CK(Map conf=[:]){ - show_node_info() - - env.HSA_ENABLE_SDMA=0 - env.DOCKER_BUILDKIT=1 - checkout scm - - def image = getDockerImageName() - def prefixpath = conf.get("prefixpath", "/opt/rocm") - - // Jenkins is complaining about the render group - def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" - if (conf.get("enforce_xnack_on", false)) { - dockerOpts = dockerOpts + " --env HSA_XNACK=1 " - } - def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " - } - def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') - def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') - dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " - echo "Docker flags: ${dockerOpts}" - - def variant = env.STAGE_NAME - def retimage - - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ - sh 'rocminfo | tee rocminfo.log' - if ( !runShell('grep -n "gfx" rocminfo.log') ){ - throw new Exception ("GPU not found") + // disable performance tests on gfx1030 for now. + //else if ( arch == 3){ + // run basic tests on gfx1030 + // echo "Run gemm performance tests" + // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" + // archiveArtifacts "perf_onnx_gemm_gfx10.log" + // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" + //} + else if ( arch == 4){ + // run basic tests on gfx11 + echo "Run gemm performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" + archiveArtifacts "perf_onnx_gemm_gfx11.log" + stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" + } + else if ( arch == 5 ){ + // run basic tests on gfx12 + echo "Run gemm performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" + archiveArtifacts "perf_onnx_gemm_gfx12.log" + stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" + } + else if ( arch == 6 ){ + // run basic tests on gfx908 + echo "Run performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" + archiveArtifacts "perf_onnx_gemm_gfx908.log" + stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" + } + else if ( arch == 7 ){ + // run basic tests on gfx950 + echo "Run performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950" + archiveArtifacts "perf_onnx_gemm_gfx950.log" + stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950" } - else{ - echo "GPU is OK" } } - } - } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e - } - withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 24, unit: 'HOURS') - { - //check whether to run performance tests on this node - def do_perf_tests = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx1201" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){ - do_perf_tests = 1 - echo "Stash profiler and run performance tests" - } - cmake_build(conf) - dir("build"){ - //run tests and examples - //sh 'make -j check' - if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){ - //we only need the ckProfiler to run the performance tests, so we pack and stash it - //do not stash profiler on nodes where we don't need to run performance tests - sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash name: "ckProfiler.tar.gz" - } - if (params.RUN_FULL_QA && do_perf_tests == 0 ){ - // build deb packages for all gfx9 targets and prepare to export - sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - archiveArtifacts artifacts: 'composablekernel-tests_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash name: "ckprofiler_0.2.0_amd64.deb" - } - } - if (params.hipTensor_test && do_perf_tests == 0 ){ - //build and test hipTensor + if (params.hipTensor_test && arch == 1 ){ + // build and test hipTensor on gfx90a node sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip rm -rf hipTensor-"${params.hipTensor_branch}" @@ -594,11 +693,9 @@ def Build_CK(Map conf=[:]){ ls -ltr CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" cmake --build build -- -j + ctest --test-dir build """ } - dir("hipTensor-${params.hipTensor_branch}/build"){ - sh 'ctest' - } } } } @@ -625,11 +722,12 @@ def Build_CK_and_Reboot(Map conf=[:]){ def process_results(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - def image = getDockerImageName() + //use older image that has user jenkins + def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group - def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -637,55 +735,67 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + try + { + echo "Pulling image: ${image}" + retimage = docker.image("${image}") + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.pull() + } } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e + catch(Exception ex) + { + error "Unable to locate image: ${image}" } } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 1, unit: 'HOURS'){ + timeout(time: 15, unit: 'MINUTES'){ try{ dir("script"){ - if (params.RUN_CK_TILE_TESTS){ + if (params.RUN_CK_TILE_FMHA_TESTS){ try{ - unstash "perf_fmha_fwd_gfx942.log" - unstash "perf_fmha_bwd_gfx942.log" - unstash "perf_fmha_fwd_gfx90a.log" - unstash "perf_fmha_bwd_gfx90a.log" + unstash "perf_fmha_log_gfx942" + unstash "perf_fmha_log_gfx90a" } catch(Exception err){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } - if (params.RUN_FULL_QA){ - // unstash perf files to master - unstash "ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - unstash "perf_gemm.log" - unstash "perf_resnet50_N256.log" - unstash "perf_resnet50_N4.log" - unstash "perf_batched_gemm.log" - unstash "perf_grouped_gemm.log" - unstash "perf_grouped_conv_fwd.log" - unstash "perf_grouped_conv_bwd_data.log" - unstash "perf_grouped_conv_bwd_weight.log" - unstash "perf_gemm_bilinear.log" - unstash "perf_reduction.log" - unstash "perf_splitK_gemm.log" - unstash "perf_onnx_gemm.log" - unstash "perf_mixed_gemm.log" - sh "./process_qa_data.sh" + if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ + try{ + unstash "perf_transpose_log_gfx942" + unstash "perf_transpose_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the Transpose performance logs: ${err.getMessage()}." + } + } + if (params.RUN_CK_TILE_GEMM_TESTS){ + try{ + unstash "perf_tile_gemm_log_gfx942" + unstash "perf_tile_gemm_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the GEMM performance logs: ${err.getMessage()}." + } + } + if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ + // unstash deb packages + unstash "packages" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } else{ // unstash perf files to master - unstash "perf_gemm.log" - unstash "perf_resnet50_N256.log" - unstash "perf_resnet50_N4.log" + unstash "perf_log" + try{ + unstash "perf_log_gfx11" + unstash "perf_log_gfx12" + } + catch(Exception err){ + echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." + } sh "./process_perf_data.sh" } } @@ -702,12 +812,13 @@ def process_results(Map conf=[:]){ } } -//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2; RUN_CK_TILE_TESTS=true - 0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true - 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : "" +//launch develop branch daily jobs +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true + 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true + 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" pipeline { agent none @@ -728,12 +839,12 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.2', - description: 'Specify which ROCM version to use: 6.2 (default).') + defaultValue: '6.4.1', + description: 'Specify which ROCM version to use: 6.4.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', @@ -747,7 +858,7 @@ pipeline { defaultValue: false, description: "Select whether to run small set of performance tests (default) or full QA") booleanParam( - name: "DL_KERNELS", + name: "DISABLE_DL_KERNELS", defaultValue: false, description: "Select whether to build DL kernels (default: OFF)") booleanParam( @@ -768,28 +879,84 @@ pipeline { description: "Run the cppcheck static analysis (default: OFF)") booleanParam( name: "RUN_PERFORMANCE_TESTS", - defaultValue: true, - description: "Run the performance tests (default: ON)") + defaultValue: false, + description: "Run the performance tests (default: OFF)") booleanParam( name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, description: "Run the grouped conv large cases tests (default: OFF)") booleanParam( - name: "RUN_CK_TILE_TESTS", + name: "RUN_CODEGEN_TESTS", + defaultValue: true, + description: "Run codegen tests (default: ON)") + booleanParam( + name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, - description: "Run the ck_tile tests (default: OFF)") + description: "Run the ck_tile FMHA tests (default: OFF)") + booleanParam( + name: "RUN_CK_TILE_TRANSPOSE_TESTS", + defaultValue: false, + description: "Run the ck_tile Transpose tests (default: OFF)") + booleanParam( + name: "RUN_CK_TILE_GEMM_TESTS", + defaultValue: false, + description: "Run the ck_tile GEMM tests (default: OFF)") + booleanParam( + name: "RUN_TILE_ENGINE_GEMM_TESTS", + defaultValue: false, + description: "Run the tile_engine_gemm tests (default: OFF)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, description: "Test building instances for various architectures simultaneously (default: OFF)") booleanParam( - name: "BUILD_GFX12", + name: "BUILD_GFX908", defaultValue: false, - description: "Build CK and run tests on gfx12 (default: OFF)") + description: "Build CK and run tests on gfx908 (default: OFF)") + booleanParam( + name: "BUILD_GFX90A", + defaultValue: true, + description: "Build CK and run tests on gfx90a (default: ON)") + booleanParam( + name: "BUILD_GFX942", + defaultValue: true, + description: "Build CK and run tests on gfx942 (default: ON)") + booleanParam( + name: "BUILD_GFX950", + defaultValue: false, + description: "Build CK and run tests on gfx950 (default: OFF)") + booleanParam( + name: "BUILD_GFX10", + defaultValue: true, + description: "Build CK and run tests on gfx10 (default: ON)") + booleanParam( + name: "BUILD_GFX11", + defaultValue: true, + description: "Build CK and run tests on gfx11 (default: ON)") + booleanParam( + name: "BUILD_GFX12", + defaultValue: true, + description: "Build CK and run tests on gfx12 (default: ON)") booleanParam( name: "NINJA_BUILD_TRACE", defaultValue: false, description: "Generate a ninja build trace (default: OFF)") + booleanParam( + name: "NINJA_FTIME_TRACE", + defaultValue: false, + description: "Generate a detailed time trace (default: OFF)") + booleanParam( + name: "BUILD_LEGACY_OS", + defaultValue: false, + description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)") + booleanParam( + name: "RUN_INDUCTOR_TESTS", + defaultValue: true, + description: "Run inductor codegen tests (default: ON)") + booleanParam( + name: "RUN_ALL_UNIT_TESTS", + defaultValue: false, + description: "Run all unit tests (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -798,7 +965,7 @@ pipeline { dbsshport = "${dbsshport}" dbsshuser = "${dbsshuser}" dbsshpassword = "${dbsshpassword}" - status_wrapper_creds = "${status_wrapper_creds}" + ck_git_creds = "${ck_git_creds}" gerrit_cred="${gerrit_cred}" DOCKER_BUILDKIT = "1" } @@ -834,8 +1001,8 @@ pipeline { | grep -v 'build/' \ | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ - -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \ - -D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ + -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ + -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" } @@ -884,9 +1051,9 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl""" - } + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases""" + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -894,15 +1061,38 @@ pipeline { } } } - stage("Run CK_TILE Tests") + stage("Run Codegen Tests") { parallel { - stage("Run CK_TILE Tests on gfx90a") + stage("Run Codegen Tests on gfx90a") { when { beforeAgent true - expression { params.RUN_CK_TILE_TESTS.toBoolean() } + expression { params.RUN_CODEGEN_TESTS.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + make -j64 check""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run CK_TILE_FMHA Tests") + { + parallel + { + stage("Run CK_TILE_FMHA Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } } agent{ label rocmnode("gfx90a") } environment{ @@ -911,17 +1101,17 @@ pipeline { make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() } } - stage("Run CK_TILE Tests on gfx942") + stage("Run CK_TILE_FMHA Tests on gfx942") { when { beforeAgent true - expression { params.RUN_CK_TILE_TESTS.toBoolean() } + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } } agent{ label rocmnode("gfx942") } environment{ @@ -930,7 +1120,7 @@ pipeline { make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ - } + } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() @@ -938,30 +1128,190 @@ pipeline { } } } + stage("Run CK_TILE_TRANSPOSE Tests") + { + parallel + { + stage("Run CK_TILE_TRANSPOSE Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_batched_transpose && \ + cd ../ && + example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE_TRANSPOSE Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_batched_transpose && \ + cd ../ && + example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run CK_TILE_GEMM Tests") + { + parallel + { + stage("Run CK_TILE_GEMM Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_gemm_universal && \ + cd ../ && + example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE_GEMM Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_gemm_universal && \ + cd ../ && + example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run TILE_ENGINE_GEMM Tests") + { + parallel + { + stage("Run TILE_ENGINE_GEMM Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx90a" \ + -D GEMM_DATATYPE="fp8;fp16" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm_fp8 && \ + ./bin/benchmark_gemm_fp8 && \ + ninja -j64 benchmark_gemm_fp16 && \ + ./bin/benchmark_gemm_fp16 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run TILE_ENGINE_GEMM Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx942" \ + -D GEMM_DATATYPE="fp8;fp16" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j128 benchmark_gemm_fp8 && \ + ./bin/benchmark_gemm_fp8 && \ + ninja -j128 benchmark_gemm_fp16 && \ + ./bin/benchmark_gemm_fp16 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Build CK and run Tests") { parallel { - stage("Build CK for all gfx9 targets") + stage("Build CK with RHEL8") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ - -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ - -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ - -DCMAKE_CXX_FLAGS=" -O3 " """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ - cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ - -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3" + setup_args = """ -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " \ + -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + execute_args = " " } steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) + cleanWs() + } + } + stage("Build CK with SLES15") + { + when { + beforeAgent true + expression { params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_sles15_rocm6.3" + setup_args = """ -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " \ + -DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """ + execute_args = " " + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: " ", no_reboot:true, build_type: 'Release', docker_name: docker_name) cleanWs() } } @@ -969,15 +1319,62 @@ pipeline { { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { (params.BUILD_GFX942.toBoolean() || params.RUN_FULL_QA.toBoolean()) && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx942") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on gfx950") + { + when { + beforeAgent true + expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx950") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx950" \ + -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx950" \ + -DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on gfx908") + { + when { + beforeAgent true + expression { params.BUILD_GFX908.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx908") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -989,7 +1386,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + expression { params.BUILD_GFX90A.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx90a") } environment{ @@ -998,6 +1395,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1005,22 +1403,27 @@ pipeline { cleanWs() } } - stage("Build CK instances for different targets") + stage("Build CK instances for all supported targets") { when { beforeAgent true - expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() } + expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx90a") } - environment{ - execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER="${build_compiler()}" \ - -D CMAKE_BUILD_TYPE=Release \ - -D INSTANCES_ONLY=ON \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ - } + agent{ label rocmnode("gfx942") } steps{ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + script { + def execute_args = params.NINJA_FTIME_TRACE ? + """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """ : + """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ + + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + } cleanWs() } } @@ -1028,15 +1431,16 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + expression { params.BUILD_GFX10.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1030") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx10-3-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx1030" \ + -DGPU_TARGETS="gfx10-3-generic" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1048,15 +1452,16 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1101") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx1101" \ + -DGPU_TARGETS="gfx11-generic" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1068,15 +1473,16 @@ pipeline { { when { beforeAgent true - expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } + expression { params.BUILD_GFX12.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1201") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1201" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx1201" \ + -DGPU_TARGETS="gfx12-generic" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1086,29 +1492,6 @@ pipeline { } } } - - stage("Performance Tests") - { - parallel - { - stage("Run ckProfiler: gfx90a") - { - when { - beforeAgent true - expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } - } - options { retry(1) } - agent{ label rocmnode("gfx90a")} - environment{ - setup_args = "NO_CK_BUILD" - } - steps{ - runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') - cleanWs() - } - } - } - } stage("Process Performance Test Results") { parallel @@ -1116,7 +1499,7 @@ pipeline { stage("Process results"){ when { beforeAgent true - expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent { label 'mici' } steps{ diff --git a/LICENSE b/LICENSE index 581b5efde5..68f6ae5746 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 4889914691..29d3d4e85a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # Composable Kernel +> [!NOTE] +> The published documentation is available at [Composable Kernel](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/) in an organized, easy-to-read format, with search and a table of contents. The documentation source files reside in the `docs` folder of this repository. As with all ROCm projects, the documentation is open source. For more information on contributing to the documentation, see [Contribute to ROCm documentation](https://rocm.docs.amd.com/en/latest/contribute/contributing.html). + The Composable Kernel (CK) library provides a programming model for writing performance-critical kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library uses general purpose kernel languages, such as HIP C++. @@ -23,23 +26,15 @@ The current CK library is structured into four layers: ## General information -To build our documentation locally, use the following code: - -``` bash -cd docs -pip3 install -r sphinx/requirements.txt -python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html -``` - -You can find a list of our developers and contributors on our [Contributors](/CONTRIBUTORS.md) page. - -```note -If you use CK, cite us as follows: - -* [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???): - This paper will be available on arXiv soon. -* [CITATION.cff](/CITATION.cff) -``` +* [CK supported operations](include/ck/README.md) +* [CK Tile supported operations](include/ck_tile/README.md) +* [CK wrapper](client_example/25_wrapper/README.md) +* [CK codegen](codegen/README.md) +* [CK profiler](profiler/README.md) +* [Examples (Custom use of CK supported operations)](example/README.md) +* [Client examples (Use of CK supported operations with instance factory)](client_example/README.md) +* [Terminology](/TERMINOLOGY.md) +* [Contributors](/CONTRIBUTORS.md) CK is released under the **[MIT license](/LICENSE)**. @@ -78,7 +73,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s) you want to run CK on. You can specify single or multiple architectures. If you specify multiple architectures, - use a semicolon between each; for example, `gfx908;gfx90a;gfx940`. + use a semicolon between each; for example, `gfx908;gfx90a;gfx942`. ```bash cmake \ @@ -90,7 +85,13 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ``` If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets - supported by the current compiler (this may take a long time). + supported by the current compiler (this may take a long time). + Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line. + + NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the + architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you + want to build the library for a list of different architectures, + you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`. 4. Build the entire CK library: @@ -103,6 +104,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ```bash make -j install ``` + **[See Note on -j](#notes)** ## Optional post-install steps @@ -120,6 +122,15 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa You can find instructions for running each individual example in [example](/example). +* Build and run smoke/regression examples and tests: + + ```bash + make -j smoke # tests and examples that run for < 30 seconds each + ``` + ```bash + make -j regression # tests and examples that run for >= 30 seconds each + ``` + * Build ckProfiler: ```bash @@ -128,27 +139,38 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa You can find instructions for running ckProfiler in [profiler](/profiler). -Note the `-j` option for building with multiple threads in parallel. This speeds up the build significantly. -Depending on the number of CPU cores and the amount of RAM on your system, you may want to -limit the number of threads. For example, if you have a 128-core CPU and 64 Gb of RAM. +* Build our documentation locally: -By default, `-j` launches one thread per CPU core, which can cause the build to run out of memory and -crash. In such cases, you can reduce the number of threads to 32 by using `-j32`. + ``` bash + cd docs + pip3 install -r sphinx/requirements.txt + python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html + ``` + +### Notes +The `-j` option for building with multiple threads in parallel, which speeds up the build significantly. +However, `-j` launches unlimited number of threads, which can cause the build to run out of memory and +crash. On average, you should expect each thread to use ~2Gb of RAM. +Depending on the number of CPU cores and the amount of RAM on your system, you may want to +limit the number of threads. For example, if you have a 128-core CPU and 128 Gb of RAM it's advisable to use `-j32`. Additional cmake flags can be used to significantly speed-up the build: -* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library - while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a - dependency and don't plan to run any examples or tests. - * `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances of select data types only. The main default data types are fp32 and fp16; you can safely skip other data types. -* `DL_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dl` or +* `DISABLE_DL_KERNELS` (default is OFF) must be set to ON in order not to build instances, such as `gemm_dl` or `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most other platforms have faster instances, such as `xdl` or `wmma`, available. +* `DISABLE_DPP_KERNELS` (default is OFF) must be set to ON in order not to build instances, such as `gemm_dpp`. + These instances offer a slightly better performance of fp16 gemms on NAVI2x. But on other architectures faster alternatives are available. + +* `CK_USE_FP8_ON_UNSUPPORTED_ARCH` (default is OFF) must be set to ON in order to build instances, + such as `gemm_universal`, `gemm_universal_streamk` and `gemm_multiply_multiply` for fp8 data type for GPU targets which do not have native support for fp8 data type, such as gfx908 or gfx90a. These instances are useful on + architectures like the MI100/MI200 for the functional support only. + ## Using sccache for building The default CK Docker images come with a pre-installed version of sccache, which supports clang @@ -191,4 +213,4 @@ script/uninstall_precommit.sh ``` If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the -`git commit` command. +`git commit` command. \ No newline at end of file diff --git a/TERMINOLOGY.md b/TERMINOLOGY.md new file mode 100644 index 0000000000..e8833efb89 --- /dev/null +++ b/TERMINOLOGY.md @@ -0,0 +1,2 @@ +[Back to the main page](./README.md) +# Composable Kernel terminology \ No newline at end of file diff --git a/client_example/01_gemm/README.md b/client_example/01_gemm/README.md new file mode 100644 index 0000000000..6dcd1e2959 --- /dev/null +++ b/client_example/01_gemm/README.md @@ -0,0 +1,126 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel GEMM + +## GEMM +General matrix multiplications operation. In CK GEMM operation is called as `DeviceGemm` and requires following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - B matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - B matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +For matrices with large K dimension `DeviceGemmSplitK` implementation is available. This implementation allows user to split K dimension between work groups. This implementation uses `AtomicAdd` operation on global memory, thus need to zero-out output buffer for correct results. + +For fused operations with additional tensor there are `DeviceGemmMultipleABD` or `DeviceGemmMultipleD` operation which require following parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +For `DeviceGemmMultipleABD` **ALayout**, **BLayout**, **ADataType** and **BDataType** user should pass a tuple. + +List of the device operations in CK: + +* **DeviceGemmDl** - Device operation with DL instructions. +* **DeviceGemmDpp** - Device operation with DL instructions with DPP instructions during data load. +* **DeviceGemmWmma_CShuffle** - Device operation with WMMA instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffle_LdsDirectLoad** - Device operation with XDL instructions and CShuffle optimization for more optimized data store and direct load from global memory to shared memory. +* **DeviceGemm_Xdl_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV2** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. GEMM pipeline has been optimized compared to **DeviceGemm_Xdl_CShuffle**. +* **DeviceGemmXdlSkipBLds** - Device operation with XDL instructions. Load to shared memory has been skiped for B matrix. +* **DeviceGemm_Xdl_WaveletModel_CShuffle** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. Producer and consumer scheme cooperation between waves in workgroup. +* **DeviceGemmXdl** - Device operation with XDL instructions. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✓| + +Table of supported cases by instance factory with WMMA instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with DL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✗| +|fp16|✓| +|fp32|✓| +|int8|✓| +|fp8 |✗| + +Table of supported cases by instance factory with fused output elementwise operation: + +* **B Matrix Multiply + Add + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Add** - bf16 (int8 for B matrix) +* **B Matrix Multiply + Gelu** - bf16 (int8 for B matrix) +* **B Matrix Multiply** - bf16 (int8 for B matrix) + +* **Add + Add + Gelu** - fp16 +* **Add + Gelu** - fp16, bf16 (int8 for B matrix) for Row/Column/Row +* **Multiply** - fp16 +* **Add + Multiply** - fp16 +* **Add + Relu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add + Silu** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Add** - fp16 (int8 for B matrix) for Row/Column/Row, bf16 (int8 for B matrix) for Row/Column/Row +* **Bilinear** - fp16, int8 +* **Gelu** - fp16 +* **Multiply + Add** - fp16 for Row/Column/Row and Row/Row/Row, fp16 (int8 for B matrix, fp32 for Bias) for Row/Column/Row and Row/Row/Row, +* **Quantization** - int8 + +## GEMM V2 (Universal GEMM) +General matrix multiplications operation optimized for MI300 series. Operation is called as `DeviceGemmV2` and requires following types as template parameters: + +* **ALayout** - A matrix layout (RowMajor/ColumnMajor). +* **BLayout** - B matrix layout (RowMajor/ColumnMajor). +* **CLayout** - B matrix layout (RowMajor/ColumnMajor). +* **ADataType** - A matrix data type. +* **BDataType** - B matrix data type. +* **CDataType** - B matrix data type. +* **AElementwiseOperation** - Fused operation on tensor A before GEMM. +* **BElementwiseOperation** - Fused operation on tensor B before GEMM. +* **CElementwiseOperation** - Fused operation on tensor C after GEMM. + +This implementation allows user to split K dimension between work groups. This implementation requires AtomicAdd operation on global memory (output buffer must be set to zeroes if splitK parameter is larger than one). + +List of the device operations for in CK: + +* **DeviceGemm_Xdl_CShuffleV3** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. +* **DeviceGemm_Xdl_CShuffleV3R1** - Device operation with XDL instructions with CShuffle optimization for more optimized data store. This implementation perform reduction on splitted K dimension after GEMM instead of AtomicAdd instruction. + +Table of supported cases by instance factory with XDL instruction for Row/Row/Row, Row/Column/Row, Column/Row/Row or Column/Column/Row: + +| |Is supported| +|-------|---| +|bf16|✓| +|fp16|✓| +|fp32|✗| +|int8|✗| +|fp8 (C bf16)|✓| +|fp16 (A fp8)|✓| +|fp16 (B fp8)|✓| + +## Others + +* **DeviceGemm_dequantB** - GEMM with dequantization (implemented with WMMA instructions). +* **DeviceGemmMultipleD_ABScale** - GEMM with scale for A and B matrix. +* **DeviceGemmMultipleDLayernorm** - GEMM fused with layernorm. +* **DeviceGemmMultipleDMultipleR** - GEMM fused with reductions and custom global reductions operators. +* **DeviceGemmReduce** - GEMM fused with reduction. +* **DeviceGemm_Streamk_V2** - GEMM stream K implementation. Implementation allows to use reduction instead of AtomicAdd. +* **DeviceGemmStreamK** - GEMM stream K implementation using AtomicAdd. diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index c953e21d02..2ea31bdf06 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -22,4 +22,7 @@ if(GPU_TARGETS MATCHES "gfx9") add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp) target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() + + add_executable(grouped_conv2d_fwd_ngchw grouped_conv2d_fwd_ngchw.cpp) + target_link_libraries(grouped_conv2d_fwd_ngchw PRIVATE composable_kernel::device_conv_operations) endif() diff --git a/client_example/07_grouped_convnd_fwd/README.md b/client_example/07_grouped_convnd_fwd/README.md new file mode 100644 index 0000000000..9e96df222d --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/README.md @@ -0,0 +1,68 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Forward +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Forward operation is called as `DeviceGroupedConvFwdMultipleABD` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **ADataType** - input data type. Pass tuple if there is fused operation with input. +* **BDataType** - weight data type. Pass tuple if there is fused operation with weight. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - Output data type. +* **AElementwiseOperation** - fused operation on tensor A (input). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (output). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution forward support tensors larger than 2GB. + +List of the device operations for grouped convolution forward in CK: + +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3** - Device operation with XDL instructions. Optimized for AMD Instinct MI300 series. +* **DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to input, weight and output. +* **DeviceGroupedConvFwdMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK** - Device operation with DL instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---|---| +|bf16 |2D, 3D|2D|2D|1D, 2D, 3D| +|fp16 |2D, 3D|2D|2D|1D, 2D, 3D| +|fp32 |2D, 3D|2D|2D|1D, 2D, 3D| +|int8 |2D, 3D|2D|2D|1D, 3D| +|fp8 |3D|✗|✗|✗| +|bf8 |3D|✗|✗|✗| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16 |✗|✗|2D| +|fp16 |✗|✗|2D| +|fp32 |✗|✗|2D| +|int8 |✗|✗|2D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Dynamic elementwise operation** - 2D/3D, NHWGC, bf16/fp16/fp32/int8 +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **ConvInvScale** - 3D, NHWGC, fp8 +* **ConvScale** - 3D, NHWGC, fp8/bf8 +* **ConvScale + Add** - 3D, NHWGC, fp8 +* **ConvScale + Relu** - 3D, NHWGC, fp8 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add (for A and B)** - 3D, NHWGC, bf16/fp16/fp32/int8 +* **Scale + Add + Scale + Add + Relu** - 3D, NHWGC, bf16/fp16/fp32/int8 diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp new file mode 100644 index 0000000000..480abf23d2 --- /dev/null +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NGCHW; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::NGKHW; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd() +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C * Hi * Wi, G * C * Hi * Wi, Hi * Wi, Wi, 1}; + std::array wei_lengths{G, K, C, Y, X}; + std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{K * Ho * Wo, G * K * Ho * Wo, Ho * Wo, Wo, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + // workspace_sz will be equal to 0 for other layout than NGCHW + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd(); } diff --git a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt index d10c39ed80..42a29a1d42 100644 --- a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt +++ b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt @@ -1,6 +1,9 @@ add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations) +add_executable(client_grouped_conv2d_bwd_data_ngchw grouped_conv2d_bwd_data_ngchw.cpp) +target_link_libraries(client_grouped_conv2d_bwd_data_ngchw PRIVATE composable_kernel::device_conv_operations) + add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp) target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/10_grouped_convnd_bwd_data/README.md b/client_example/10_grouped_convnd_bwd_data/README.md new file mode 100644 index 0000000000..e26fc3516e --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/README.md @@ -0,0 +1,48 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Data + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. In CK Grouped Convolution Backward Data operation is called as `DeviceGroupedConvBwdDataMultipleD` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **ALayout** - output layout (NHWGK, GNHWK, NGKHW). +* **BLayout** - weight layout (GKYXC). +* **DsLayout** - layouts for additional tensors for fused operations. +* **ELayout** - input layout (NHWGC, GNHWC, NGCHW). +* **ADataType** - output data type. +* **BDataType** - weight data type. +* **DsDataType** - data types for additional tensors for fused operations. +* **EDataType** - input data type. +* **AElementwiseOperation** - fused operation on tensor A (output). +* **BElementwiseOperation** - fused operation on tensor B (weight). +* **CDEElementwiseOperation** - fused operation on tensor C (input). +* **AComputeType** - compute data type of tensor A for mfma instruction (ADataType by default). +* **BComputeType** - compute data type of tensor B for mfma instruction (AComputeType by default). + +Grouped convolution backward data supports tensors larger than 2GB (except when image is larger than 2GB). + +List of the device operations for grouped convolution backward data in CK: + +* **DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1** - Device operation with XDL instructions and support of fused operations to input. +* **DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle** - Device operation with WMMA instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16|2D, 3D|2D, 3D|2D, 3D| +|fp16 |2D, 3D|2D, 3D|2D, 3D| +|fp32 |2D, 3D|2D, 3D|2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |2D, 3D|✗|2D, 3D| +|int8 |2D, 3D|✗|2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16/fp16/fp32 diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp new file mode 100644 index 0000000000..2309d757f0 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NGCHW; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::NGKHW; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 256; +static constexpr ck::index_t K = 192; +static constexpr ck::index_t C = 192; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, Hi, Wi, C}; + std::array in_strides{ + C * Hi * Wi, G * C * Hi * Wi, Wi, 1, Hi * Wi}; + + std::array wei_lengths{G, K, Y, X, C}; + std::array wei_strides{K * Y * X * C, Y * X * C, X * C, C, 1}; + + std::array out_lengths{G, N, Ho, Wo, K}; + std::array out_strides{ + K * Ho * Wo, G * K * Ho * Wo, Wo, 1, Ho * Wo}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * G * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md new file mode 100644 index 0000000000..f1ba95e9cd --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -0,0 +1,62 @@ +[Back to supported operations](../../../include/ck/README.md) +# Composable Kernel Grouped Convolution + +## Grouped Convolution Backward Weight + +Grouped convolution operation for 1D, 2D or 3D spatial dimensions. Convolution utilizes GEMM kernel after tensor coordinate transform. Backward weight version uses splitK feature (due to large GEMM K dimension). In CK Grouped Convolution Backward Weight operation is called as `DeviceGroupedConvBwdWeight` and requires following types as template parameters: + +* **NumDimSpatial** - number of spatial dimensions (1D, 2D, 3D). +* **InLayout** - input layout (NHWGC, GNHWC, NGCHW). +* **WeiLayout** - weight layout (GKYXC). +* **OutLayout** - output layout (NHWGK, GNHWK, NGKHW). +* **InDataType** - input data type. +* **WeiDataType** - weight data type. +* **OutDataType** - output data type. +* **InElementwiseOperation** - fused operation on tensor input. +* **WeiElementwiseOperation** - fused operation on tensor weight. +* **OutElementwiseOperation** - fused operation on tensor output. +* **ComputeTypeA** - compute data type of tensor A for mfma instruction (ADataType by default). +* **ComputeTypeB** - compute data type of tensor B for mfma instruction (ComputeTypeA by default). + +For fused operations with additional tensor there is `DeviceGroupedConvBwdWeightMultipleD` operation which requires following parameters: +* **DsLayout** - layouts for additional tensors for fused operations. +* **DsDataType** - data types for additional tensors for fused operations. + +Grouped convolution backward weight doesn't supports tensors larger than 2GB. + +List of the device operations for grouped convolution backward weight in CK: + +* **DeviceGroupedConvBwdWeight_Xdl_CShuffle** - Device operation with XDL instructions. +* **DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle** - Device operation with XDL instructions. Optimized for small C or K. +* **DeviceGroupedConvBwdWeight_Wmma_CShuffle** - Device operation with WMMA instructions. +* **DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle** - Device operation with XDL instructions and support of fused operations to output. +* **DeviceGroupedConvBwdWeight_Dl** - Device operation with DL instructions. + +Table of supported cases by instance factory with XDL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---|---| +|bf16|2D, 3D|2D, 3D|2D, 3D|✗| +|bf16(fp32 for weight)|2D, 3D|✗|✗|1D, 2D, 3D| +|fp16 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D| +|fp32 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D| + +Table of supported cases by instance factory with WMMA instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|fp16 |3D|✗|3D| +|int8 |3D|✗|3D| + +Table of supported cases by instance factory with DL instruction: + +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---| +|bf16(fp32 for weight)|1D, 2D, 3D|✗|1D, 2D, 3D| +|fp16 |1D, 2D, 3D|✗|1D, 2D, 3D| +|fp32 |1D, 2D, 3D|✗|1D, 2D, 3D| + +Table of supported cases by instance factory with fused elementwise operation: + +* **Bilinear** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 +* **Scale** - 3D, NHWGC, bf16(fp32 for weight)/fp16/fp32 diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index dc55250bfe..67bbdfec45 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -54,7 +54,7 @@ target_link_libraries(client_conv3d_fwd_convscale_relu_amax_fp8 PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_reduction_operations - utility) + composable_kernel::utility) # Fwd convscale + AMAX add_executable(client_conv3d_fwd_convscale_amax_fp8 grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp) @@ -62,7 +62,7 @@ target_link_libraries(client_conv3d_fwd_convscale_amax_fp8 PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_reduction_operations - utility) + composable_kernel::utility) # Fwd convscale add_executable(client_conv3d_fwd_convscale_fp8 grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp) diff --git a/client_example/25_wrapper/README.md b/client_example/25_wrapper/README.md index eba3de017f..3db9a9af44 100644 --- a/client_example/25_wrapper/README.md +++ b/client_example/25_wrapper/README.md @@ -1,14 +1,9 @@ +[Back to the main page](../../README.md) # Composable Kernel wrapper GEMM tutorial -This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) -wrapper. We present the base version of GEMM without most of the available optimizations; however, -it's worth noting that CK has kernels with different optimizations. +This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) wrapper. We present the base version of GEMM without most of the available optimizations; however, it's worth noting that CK has kernels with different optimizations. -To implement these optimizations, you can use the CK wrapper or directly use available instances in -CK. You can also refer to the -[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), -that uses CK wrapper based on the -[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. +To implement these optimizations, you can use the CK wrapper or directly use available instances in CK. You can also refer to the [optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), that uses CK wrapper based on the [`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. The kernel definition should look similar to: diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp index 4b284c74d4..47d3e0abf9 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_bias_fastgelu_xdl_bf16_i8.cpp @@ -121,7 +121,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co constexpr ck::index_t NumDTensor = 2; using GroupedGemmKernelArgument = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); diff --git a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp index 6cc83e06f6..8c705d3bcc 100644 --- a/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp +++ b/client_example/31_grouped_gemm_bf16Aint8B/grouped_gemm_multiply_xdl_bf16_i8.cpp @@ -120,7 +120,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co constexpr ck::index_t NumDTensor = 1; using GroupedGemmKernelArgument = - ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + ck::tensor_operation::device::GroupedGemmKernelArgument; std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); diff --git a/client_example/32_gemm_mx/CMakeLists.txt b/client_example/32_gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..558986bf5a --- /dev/null +++ b/client_example/32_gemm_mx/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx950") + add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp) + target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/32_gemm_mx/gemm_mx_fp8.cpp b/client_example/32_gemm_mx/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..6e14bf2a5f --- /dev/null +++ b/client_example/32_gemm_mx/gemm_mx_fp8.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::half_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; +template +inline constexpr bool is_same_v = ck::is_same::value; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AScaleLayout = Row; +using BScaleLayout = Col; + +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + mem_size_ = mem_size; + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mem_size_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; + + /* Require by mx type*/ + constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + /* Scale stride Calculation */ + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + return static_cast(col); + else + return static_cast(row); + } + else + return static_cast(stride); + }; + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize; + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + SimpleDeviceMem a_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + SimpleDeviceMem b_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + using DeviceOp = + ck::tensor_operation::device::DeviceGemmMX; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + + sizeof(XDataType) * N * K / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index acb57d7bb0..8fdd60f5d5 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -32,7 +32,7 @@ if (DTYPES) add_definitions(-DCK_ENABLE_BF16) set(CK_ENABLE_BF16 "ON") endif() - message("DTYPES macro set to ${DTYPES}") + message(DEBUG "DTYPES macro set to ${DTYPES}") else() add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) set(CK_ENABLE_INT8 "ON") @@ -56,13 +56,21 @@ if (GPU_TARGETS) add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() + if (GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx950") + add_definitions(-DCK_USE_OCP_FP8) + set(CK_USE_OCP_FP8 "ON") + endif() + if (GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx94") + add_definitions(-DCK_USE_FNUZ_FP8) + set(CK_USE_FNUZ_FP8 "ON") + endif() else() add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) set(CK_USE_XDL "ON") set(CK_USE_WMMA "ON") endif() -find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations utility) if(GPU_TARGETS MATCHES "gfx9") find_package(composable_kernel COMPONENTS device_contraction_operations) endif() diff --git a/client_example/README.md b/client_example/README.md index 64a7130d53..34c6733d05 100644 --- a/client_example/README.md +++ b/client_example/README.md @@ -1,3 +1,5 @@ +[Back to the main page](../README.md) +# Composable Kernel client examples ## Client application links to CK library, and therefore CK library needs to be installed before building client applications. @@ -12,8 +14,10 @@ cd client_example/build cmake \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \ +-D GPU_TARGETS="gfx908;gfx90a" \ .. ``` +You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s). ### Build client example ```bash diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake index cf77991a64..d0d30d669a 100644 --- a/cmake/ClangTidy.cmake +++ b/cmake/ClangTidy.cmake @@ -144,7 +144,7 @@ function(clang_tidy_check TARGET) # COMMAND ${CLANG_TIDY_COMMAND} $, > foreach(SOURCE ${SOURCES}) if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS")) - string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file) + string(MD5 tidy_file "${SOURCE}") set(tidy_target tidy-target-${TARGET}-${tidy_file}) add_custom_target(${tidy_target} # for some targets clang-tidy not able to get information from .clang-tidy diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake index 4bc638b446..3946cf4e8d 100644 --- a/cmake/Embed.cmake +++ b/cmake/Embed.cmake @@ -233,6 +233,8 @@ function(add_embed_library EMBED_NAME) else() target_sources(${EMBED_NAME} INTERFACE $) endif() - target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include") + target_include_directories(${EMBED_NAME} INTERFACE + $ + $) endfunction() diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 93fd306e98..0c81f8df98 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,8 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + # Werror set outside by BUILD_DEV + # -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt @@ -108,7 +109,7 @@ else() endif() list(APPEND CMAKE_COMPILER_WARNINGS -Wno-missing-field-initializers - -Wno-deprecated-declarations + -Wno-error=deprecated-declarations ) endif() add_definitions(${CMAKE_COMPILER_WARNINGS}) diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake new file mode 100644 index 0000000000..47a5d0c48c --- /dev/null +++ b/cmake/ShardInstantiation.cmake @@ -0,0 +1,116 @@ +# Function to generate templated instantiation functions and caller function. + +# In order to reduce build times, we split the instantiation of template functions into multiple files. +# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, +# which can be placed the TEMPLATE_FILE (typically a .in file). + +# This CMake function generates the instantiation functions and a caller function that calls all the instantiation +# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of +# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, +# and generates a caller function that calls all the instantiation functions. + +# The explicit instatiation pattern requires the use of `extern template` to avoid implicit instantiation +# of the template functions in the caller function, and that code is automatically generated by this function. + +# In addition to the user-supplied template, this CMake function uses two generic templates: +# +# 1. `instantiate_shard.in`: This is the template for the instantiation functions. +# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. + +# This function takes the following arguments: +# +# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCE_NAMES}`). +# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. +# - NUM_SHARDS: The number of shards to generate. +# - OUTPUT_DIR: The build directory where the generated source files will be placed. +# - SRC_LIST: The list of source files to which the generated source files will be added. + + +function(generate_sharded_instantiations) + cmake_parse_arguments( + GEN_SHARDED + # No boolean arguments + "" + # Single-value arguments + "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" + # No multi-value arguments. + "" + ${ARGN} + ) + if (NOT GEN_SHARDED_INSTANCES_NAME) + message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_TEMPLATE_FILE) + message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_NUM_SHARDS) + message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") + endif() + if(NOT GEN_SHARDED_OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_SRC_LIST) + message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") + endif() + + file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR}) + + + set(GENERATED_SOURCE_FILES "") + set(EXTERN_TEMPLATE_STATEMENTS "") + set(CALL_STATEMENTS "") + message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + + set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") + + # Generate the inc file with the template function defintions. + # This include file will hold the template function definitions and a using alias for all the shard + # instantiation functions. + configure_file( + "${GEN_SHARDED_TEMPLATE_FILE}" + "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" + @ONLY + ) + + # Generate the sharded instantiation functions. + # This is where the build parallelization happens. + # Each of these source files will contain a single instantiation function for a shard, + # which will be called sequentially by the caller function. + set(INC_DIR "${GEN_SHARDED_INC_DIR}") + math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") + foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) + set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") + set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") + set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") + configure_file( + "${SHARD_FUNCTION_TEMPLATE}" + "${SHARD_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") + set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") + list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") + list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") + endforeach() + + # Join the include statements, the extern template declarations, and the call statements each + # into a single string for variable substitution in the caller function. + string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") + string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") + string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") + + # Generate the caller function. + set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") + set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") + configure_file( + "${FUNCTION_TEMPLATE}" + "${CALLER_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") + + # Add the generated source files to the list of source files. + # This allows the generated source files to be included in the build. + list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) + set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in new file mode 100644 index 0000000000..daba79b055 --- /dev/null +++ b/cmake/call_shard.in @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { + +@EXTERN_TEMPLATE_STATEMENTS@; + +void add_@INSTANCES@( + @INSTANCES@& instances) { +@CALL_STATEMENTS@; +} + +} // namespace ck::tensor_operation::device::instance diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in new file mode 100644 index 0000000000..dbc0af17a9 --- /dev/null +++ b/cmake/instantiate_shard.in @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { +template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( + @INSTANCES@& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 3b3e9f06ee..35b5cf0367 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -1,53 +1,57 @@ +cmake_minimum_required(VERSION 3.16) +project(composable_kernel_host) + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) +configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) -add_compile_options(-std=c++17) -find_package(hip) -add_custom_target(codegen) +find_package(ROCM) +include(ROCMInstallTargets) +include(ROCMTest) -# add include directories -include_directories(BEFORE - ${PROJECT_BINARY_DIR}/include - ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/library/include - ${HIP_INCLUDE_DIRS} - ) +rocm_setup_version(VERSION 1.0) list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS - ${CK_ROOT}/include/ck/*.hpp) -#printouts fot debug purposes -#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") -#message(STATUS "RELATIVE: ${CK_ROOT}/include") + ${CK_ROOT}/include/ck/*.hpp) + add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) +add_compile_options(-std=c++17) -##message(STATUS "SOURCE_FILES: ${SOURCES}") +file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) target_link_libraries(ck_host PRIVATE ck_headers) -set_target_properties(ck_host PROPERTIES - LINKER_LANGUAGE CXX - POSITION_INDEPENDENT_CODE ON) +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) -target_include_directories(ck_host PUBLIC - $ -) +# target_include_directories(ck_host PUBLIC +# $ +# ) add_executable(ck-template-driver driver/main.cpp) target_link_libraries(ck-template-driver ck_host) -rocm_install( +rocm_install_targets( TARGETS ck_host ck_headers - EXPORT ck_hostTargets + EXPORT ck_host_targets + INCLUDE include +) +rocm_export_targets( + TARGETS ck_host ck_headers + EXPORT ck_host_targets + NAMESPACE composable_kernel:: ) -rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) -add_subdirectory(test) +if(BUILD_TESTING) + add_subdirectory(test) +endif() + diff --git a/codegen/README.md b/codegen/README.md new file mode 100644 index 0000000000..deadf3221d --- /dev/null +++ b/codegen/README.md @@ -0,0 +1,2 @@ +[Back to the main page](../README.md) +# Composable Kernel codegen \ No newline at end of file diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp index c7d295de94..7b878d0d57 100644 --- a/codegen/driver/main.cpp +++ b/codegen/driver/main.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp new file mode 100644 index 0000000000..301df0a529 --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/operation.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines all values need for an instance of fwd conv +struct Operation_Xdl_CShuffle +{ + // returns a vector of instances, only given fusion operators: will use default problem spec + static std::vector> + CreateOperations(const std::string& prologue, const std::string& epilogue); + // returns a vector of instances, given a problem spec and fusion operators + static std::vector + CreateOperations(const Problem& prob, const std::string& prologue, const std::string& epilogue); + TensorDesc A{}; + TensorDesc B{}; + TensorDesc B1{}; + TensorDesc C{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string b1_elem_op = PassThrough; + std::string c_elem_op = PassThrough; + std::string acc_elem_op = Scale; + std::string prologue = ""; + std::string epilogue = ""; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + // tuning parameters + operation::TileDescGemmGemm tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b0_block_transfer{}; + operation::BlockTransferDesc b1_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + bool mask_out_upper_triangle = false; + + // functions to update fusion operators if provided + void update_prologue(const std::string& prologue); + void update_epilogue(const std::string& epilogue); + /**constexpr**/ bool + IsSupported(std::size_t MRaw_, std::size_t NRaw_, std::size_t KRaw_, std::size_t Gemm1NRaw_); + // returns a templated instance + Solution ToSolution() const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp new file mode 100644 index 0000000000..30dd1487ca --- /dev/null +++ b/codegen/include/ck/host/device_batched_gemm_softmax_gemm/problem.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// defines the problem specification for a GEMM operation +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + std::size_t O = 0; + bool TransA = false; + bool TransB = false; + bool TransB1 = false; + bool TransC = false; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType B1DataType = DataType::Half; + DataType CDataType = DataType::Half; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string B1ElementOp = PassThrough; + std::string CElementOp = PassThrough; + std::string AccElementOp = Scale; + bool MaskOutUpperTriangle = false; + + // returns the correct device op file for the operation + std::string GetIncludeHeader() const; + + // returns a list of instances based on the problem spec and provided fusion operations + std::vector GetSolutions(const std::string& arch, + const std::string& prologue = "", + const std::string& epilogue = "") const; +}; + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp index 359da7d8cf..e5eeb6be15 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -41,6 +41,8 @@ struct Operation_Xdl_CShuffle operation::BlockTransferDesc b_block_transfer{}; operation::CShuffleDesc cshuffle{}; operation::CBlockTransferDesc c_block_transfer{}; + LoopScheduler loop_scheduler{}; + PipelineVersion pipeline_version{}; // functions to update fusion operators if provided void update_prologue(const std::string& prologue); diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp index f4036328ec..1c65fb71ff 100644 --- a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -37,8 +37,8 @@ struct Problem // returns a list of instances based on the problem spec and provided fusion operations std::vector GetSolutions(const std::string& arch, - const std::string& prologue, - const std::string& epilogue) const; + const std::string& prologue = "", + const std::string& epilogue = "") const; }; } // namespace device_gemm_multiple_d diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp index 84ef92f0a0..5a51a0002e 100644 --- a/codegen/include/ck/host/operation/gemm.hpp +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -23,6 +23,26 @@ struct TileDesc int n_Xdl_per_wave = 0; int num_gemmk_prefetch_stage = 0; }; + +struct TileDescGemmGemm +{ + int block_size = 0; + int gemm01_m_per_block = 0; + int gemm0_n_per_block = 0; + int gemm0_k_per_block = 0; + int gemm1_n_per_block = 0; + int gemm1_k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int b1k1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int gemm0_m_Xdl_per_wave = 0; + int gemm0_n_Xdl_per_wave = 0; + int gemm1_n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; + struct BlockTransferDesc { std::string thread_cluster_length = ""; diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp index 8bad7bf89c..b05e134176 100644 --- a/codegen/include/ck/host/types.hpp +++ b/codegen/include/ck/host/types.hpp @@ -66,6 +66,20 @@ enum class GemmType }; std::string ToString(GemmType gt); +enum class LoopScheduler +{ + Default, + Interwave, +}; +std::string ToString(LoopScheduler ls); + +enum class PipelineVersion +{ + v1, + v2 +}; +std::string ToString(PipelineVersion pv); + struct TensorDesc { DataType element; @@ -84,6 +98,7 @@ const std::string S = SequenceStr({xs...}); constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; +constexpr const char* Scale = "ck::tensor_operation::element_wise::Scale"; } // namespace host } // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm.cpp b/codegen/src/device_batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..cf140ead1d --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm.cpp @@ -0,0 +1,38 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// return the relevant device op file based on the operation +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp"; +} + +// returns templated instances when provided with a problem specification +std::vector Problem::GetSolutions(const std::string& arch, + const std::string& prologue, + const std::string& epilogue) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_batched_gemm_softmax_gemm::Operation_Xdl_CShuffle::CreateOperations( + *this, prologue, epilogue); // obtains vector of instances + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); // template instance with correct values + }); + return result; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000..6029ab0c7d --- /dev/null +++ b/codegen/src/device_batched_gemm_softmax_gemm_operation_xdl_cshuffle.cpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_batched_gemm_softmax_gemm/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_batched_gemm_softmax_gemm { + +// calculate appropriate Gemm Specification based on input tensor dimensions +std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t n1, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block, + const std::size_t n1_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(integer_divide_ceil(n1, n1_per_block) * n1_per_block - n1 != 0) + spec += "O"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +// function to update prologue/epilogue with user provided operation +void Operation_Xdl_CShuffle::update_prologue(const std::string& pro) +{ + if(!prologue.empty()) + { + this->prologue = pro; + } + else + { + this->prologue = ""; + } +} + +void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) +{ + if(!epilogue.empty()) + { + this->epilogue = epi; + } + else + { + this->epilogue = ""; + } +} + +// accounts for all possible combinations of Row/Col major +static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + +// Hard-code tuning parameters in modularized fashion, string them together into a vector of +// instances +std::vector Operation_Xdl_CShuffle::CreateOperations( + const Problem& prob, const std::string& prologue, const std::string& epilogue) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| NumGemmK| +// Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| Prefetch| +// | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Stage| +// | | | | | | | | | | | Wave| Wave| Wave| | + { 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, + { 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, 1}, + { 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, 1}, +// Padded fallback kernel + { 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, 1}, + { 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, 1}, +// Irregular k + { 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, 1}, + { 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, 1}, + { 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, 1}, + { 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, 1}, + { 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, 1}, + { 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, 1}, + // clang-format on + }; + + const std::vector a_block_descriptions = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Padded fallback kernel + { S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true}, +// Irregular k + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + { S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false}, + // clang-format on + }; + + const std::vector b1_block_descriptions = { + // clang-format off +// B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Padded fallback kernel + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, +// Irregular k + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + { S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 8}, + { 1, 4}, + { 1, 8}, + { 1, 4}, +// Padded fallback kernel + { 1, 2}, + { 1, 2}, +// Irregular k + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + { 1, 2}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1,16>, 8}, + { S<1, 32, 1, 8>, 8}, +// Padded fallback kernel + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, +// Irregular k + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b1_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b0_block_transfer = a_block_descriptions[i]; // b0 same as a + x.b1_block_transfer = b1_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.B1 = TensorDesc{prob.B1DataType, ToLayout(prob.TransB1)}; + x.C = TensorDesc{prob.CDataType, ToLayout(prob.TransC)}; + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.b1_elem_op = prob.B1ElementOp; + x.c_elem_op = prob.CElementOp; + x.acc_elem_op = prob.AccElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + prob.O, + x.tile_desc.gemm01_m_per_block, + x.tile_desc.gemm0_n_per_block, + x.tile_desc.gemm0_k_per_block, + x.tile_desc.gemm1_n_per_block); + x.update_prologue(prologue); + x.update_epilogue(epilogue); + x.mask_out_upper_triangle = prob.MaskOutUpperTriangle; + result.push_back(x); + } + return result; +} + +// set up instances when not provided with a problem specification, use default operation values and +// all possible layout combinations +std::vector> +Operation_Xdl_CShuffle::CreateOperations(const std::string& prologue, const std::string& epilogue) +{ + std::vector problems; + + Problem prob; + prob.TransA = false; + prob.TransB = true; + prob.TransB1 = false; + prob.TransC = false; + problems.push_back(prob); + + prob.MaskOutUpperTriangle = true; + problems.push_back(prob); + + return Transform(problems, + [&](const Problem& p) { return CreateOperations(p, prologue, epilogue); }); +} + +static const char* const DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${LayoutA}, " + "${LayoutB0}, ${LayoutB1}, ${LayoutC}, ${ADataType}, ${B0DataType}, ${B1DataType}, " + "${CDataType}, ${AccDataType}, ${CShuffleDataType}, ${AElementwiseOperation}, " + "${B0ElementwiseOperation}, ${Acc0ElementwiseOperation}, ${B1ElementwiseOperation}, " + "${CElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${Gemm01MPerBlock}, ${Gemm0NPerBlock}, ${Gemm0KPerBlock}, ${Gemm1NPerBlock}, " + "${Gemm1KPerBlock}, ${AK1}, ${BK1}, ${B1K1}, ${MPerXDL}, ${NPerXDL}, ${Gemm0MXdlPerWave}, " + "${Gemm0NXdlPerWave}, ${Gemm1NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${B0BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B0BlockTransferThreadClusterArrangeOrder}, ${B0BlockTransferSrcAccessOrder}, " + "${B0BlockTransferSrcVectorDim}, ${B0BlockTransferSrcScalarPerVector}, " + "${B0BlockTransferDstScalarPerVector_BK1}, ${B0BlockLdsExtraN}, " + "${B1BlockTransferThreadClusterLengths_BK0_N_BK1}, " + "${B1BlockTransferThreadClusterArrangeOrder}, ${B1BlockTransferSrcAccessOrder}, " + "${B1BlockTransferSrcVectorDim}, ${B1BlockTransferSrcScalarPerVector}, " + "${B1BlockTransferDstScalarPerVector_BK1}, ${B1BlockLdsExtraN}, " + "${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl}, " + "${CBlockTransferScalarPerVector_NWaveNPerXdl}, ${MaskOutUpperTriangle}>"; + +// use hardcoded instances from vector of operations to substitute values into instance template +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"name", + std::to_string(this->tile_desc.block_size) + "_" + + std::to_string(this->tile_desc.gemm01_m_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm0_k_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_n_per_block) + "_" + + std::to_string(this->tile_desc.gemm1_k_per_block) + "_" + + std::to_string(this->tile_desc.ak1) + "_" + std::to_string(this->tile_desc.bk1) + "_" + + std::to_string(this->tile_desc.b1k1) + "_" + + std::to_string(this->tile_desc.m_per_XDL) + "_" + + std::to_string(this->tile_desc.n_per_XDL) + "_" + + std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave) + "_" + + std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB0", ToString(this->B.layout)}, + {"LayoutB1", ToString(this->B1.layout)}, + {"LayoutC", ToString(this->C.layout)}, + {"ADataType", ToString(this->A.element)}, + {"B0DataType", ToString(this->B.element)}, + {"B1DataType", ToString(this->B1.element)}, + {"CDataType", ToString(this->C.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"AElementwiseOperation", this->a_elem_op}, + {"B0ElementwiseOperation", this->b_elem_op}, + {"Acc0ElementwiseOperation", this->acc_elem_op}, + {"B1ElementwiseOperation", this->b1_elem_op}, + {"CElementwiseOperation", this->c_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"Gemm01MPerBlock", std::to_string(this->tile_desc.gemm01_m_per_block)}, + {"Gemm0NPerBlock", std::to_string(this->tile_desc.gemm0_n_per_block)}, + {"Gemm0KPerBlock", std::to_string(this->tile_desc.gemm0_k_per_block)}, + {"Gemm1NPerBlock", std::to_string(this->tile_desc.gemm1_n_per_block)}, + {"Gemm1KPerBlock", std::to_string(this->tile_desc.gemm1_k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"B1K1", std::to_string(this->tile_desc.b1k1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"Gemm0MXdlPerWave", std::to_string(this->tile_desc.gemm0_m_Xdl_per_wave)}, + {"Gemm0NXdlPerWave", std::to_string(this->tile_desc.gemm0_n_Xdl_per_wave)}, + {"Gemm1NXdlPerWave", std::to_string(this->tile_desc.gemm1_n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"B0BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b0_block_transfer.thread_cluster_length}, + {"B0BlockTransferThreadClusterArrangeOrder", + this->b0_block_transfer.thread_cluster_arrange_order}, + {"B0BlockTransferSrcAccessOrder", this->b0_block_transfer.src_access_order}, + {"B0BlockTransferSrcVectorDim", std::to_string(this->b0_block_transfer.src_vec_dim)}, + {"B0BlockTransferSrcScalarPerVector", + std::to_string(this->b0_block_transfer.src_scalar_per_vector)}, + {"B0BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b0_block_transfer.dst_scalar_per_vector_k1)}, + {"B0BlockLdsExtraN", std::to_string(this->b0_block_transfer.lds_add_extra_dim)}, + {"B1BlockTransferThreadClusterLengths_BK0_N_BK1", + this->b1_block_transfer.thread_cluster_length}, + {"B1BlockTransferThreadClusterArrangeOrder", + this->b1_block_transfer.thread_cluster_arrange_order}, + {"B1BlockTransferSrcAccessOrder", this->b1_block_transfer.src_access_order}, + {"B1BlockTransferSrcVectorDim", std::to_string(this->b1_block_transfer.src_vec_dim)}, + {"B1BlockTransferSrcScalarPerVector", + std::to_string(this->b1_block_transfer.src_scalar_per_vector)}, + {"B1BlockTransferDstScalarPerVector_BK1", + std::to_string(this->b1_block_transfer.dst_scalar_per_vector_k1)}, + {"B1BlockLdsExtraN", std::to_string(this->b1_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CBlockTransferScalarPerVector_NWaveNPerXdl", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"MaskOutUpperTriangle", std::to_string(this->mask_out_upper_triangle)}, + }; + + return Solution{InterpolateString(DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_batched_gemm_softmax_gemm +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp index fff75c1962..fe556615e0 100644 --- a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -62,6 +62,12 @@ void Operation_Xdl_CShuffle::update_epilogue(const std::string& epi) // accounts for all possible combinations of Row/Col major static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } +// clang-format off +// DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, + +// DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +// clang-format on + // Hard-code tuning parameters in modularized fashion, string them together into a vector of // instances std::vector Operation_Xdl_CShuffle::CreateOperations( @@ -83,6 +89,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, +// Irregular tile + { 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, 1}, // clang-format on }; @@ -100,6 +108,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -109,15 +119,17 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( // ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| // Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | // | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, - {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, - {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, }; std::vector b_block_descriptions_rowmajor = { @@ -134,6 +146,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, // clang-format on }; @@ -151,6 +165,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, +// Irregular tile + { S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1}, // clang-format on }; @@ -167,6 +183,7 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { 1, 1}, { 1, 1}, { 1, 1}, + { 1, 1}, { 1, 1}, // clang-format on }; @@ -185,6 +202,8 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( { S<1, 16, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, { S<1, 32, 1, 8>, 8}, +// Irregular tile + { S<1, 16, 1, 4>, 1}, // clang-format on }; @@ -199,33 +218,44 @@ std::vector Operation_Xdl_CShuffle::CreateOperations( assert(tile_descriptions.size() == cshuffle_descriptions.size()); assert(tile_descriptions.size() == c_block_descriptions.size()); - // Put all values together into a single operation > store into the result vector - for(std::size_t i = 0; i < tile_descriptions.size(); i++) + const std::vector> scheduler_pipeline_descriptions = + { + {LoopScheduler::Default, PipelineVersion::v1}, + {LoopScheduler::Interwave, PipelineVersion::v1}, + {LoopScheduler::Default, PipelineVersion::v2}, + }; + for(auto [loop_scheduler, pipeline_version] : scheduler_pipeline_descriptions) { - Operation_Xdl_CShuffle x; - x.tile_desc = tile_descriptions[i]; - x.a_block_transfer = a_block_descriptions[i]; - x.b_block_transfer = b_block_descriptions[i]; - x.cshuffle = cshuffle_descriptions[i]; - x.c_block_transfer = c_block_descriptions[i]; - x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; - x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; - x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; - x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { - return TensorDesc{dt, ToLayout(trans)}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; - x.gemm_specialization = GetGemmSpec(prob.M, - prob.N, - prob.K, - x.tile_desc.m_per_block, - x.tile_desc.n_per_block, - x.tile_desc.k_per_block); - x.update_prologue(prologue); - x.update_epilogue(epilogue); - result.push_back(x); + // Put all values together into a single operation > store into the result vector + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + x.loop_scheduler = loop_scheduler; + x.pipeline_version = pipeline_version; + x.update_prologue(prologue); + x.update_epilogue(epilogue); + result.push_back(x); + } } return result; } @@ -263,7 +293,7 @@ static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " - "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + "${CDEBlockTransferScalarPerVector_NPerBlock}, ${LoopScheduler}, ${PipelineVersion}>"; // use hardcoded instances from vector of operations to substitute values into instance template Solution Operation_Xdl_CShuffle::ToSolution() const @@ -336,6 +366,8 @@ Solution Operation_Xdl_CShuffle::ToSolution() const this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, {"CDEBlockTransferScalarPerVector_NPerBlock", std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + {"LoopScheduler", ToString(this->loop_scheduler)}, + {"PipelineVersion", ToString(this->pipeline_version)}, }; return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 5b0c929db3..452cd99846 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/headers.hpp" #include "ck_headers.hpp" diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp index a8a8b10c04..a60e36ca4a 100644 --- a/codegen/src/types.cpp +++ b/codegen/src/types.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/types.hpp" #include "ck/host/stringutils.hpp" #include @@ -56,6 +59,26 @@ std::string ToString(GemmType gt) throw std::runtime_error("Incorrect gemm type"); } +std::string ToString(LoopScheduler ls) +{ + switch(ls) + { + case LoopScheduler::Default: return "ck::LoopScheduler::Default"; + case LoopScheduler::Interwave: return "ck::LoopScheduler::Interwave"; + } + throw std::runtime_error("Incorrect LoopScheduler type"); +} + +std::string ToString(PipelineVersion pv) +{ + switch(pv) + { + case PipelineVersion::v1: return "ck::PipelineVersion::v1"; + case PipelineVersion::v2: return "ck::PipelineVersion::v2"; + } + throw std::runtime_error("Incorrect PipelineVersion type"); +} + std::string SequenceStr(const std::vector& v) { return "ck::Sequence<" + diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp index 19627d4cf6..c15a9fd7d3 100644 --- a/codegen/src/utils.cpp +++ b/codegen/src/utils.cpp @@ -13,7 +13,7 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y) const std::unordered_set& get_xdlop_archs() { - static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx940", "gfx942"}; + static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx942"}; return supported_archs; } diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 3ca3f39a30..48fde531da 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,19 +1,25 @@ list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) -if(NOT INSTANCES_ONLY) - foreach(TEST_SRC ${TEST_SRCS}) - set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP) - get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) - add_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) - add_dependencies(codegen codegen_test_${BASE_NAME}) - add_dependencies(tests codegen_test_${BASE_NAME}) - add_dependencies(check codegen_test_${BASE_NAME}) - add_test(NAME codegen_test_${BASE_NAME} COMMAND codegen_test_${BASE_NAME}) - message("adding test codegen_test_${BASE_NAME}") - target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) - target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/codegen/test/include) + +# TODO: These tests need to be refactored to remove dependency on main ck +# headers and device compilation. +set(TESTS_REQUIRE_DEVICE_COMPILE + grouped_conv_fwd_multiple_d_v1 + grouped_conv_fwd_multiple_d_v2 + grouped_conv_fwd_multiple_d_v3 + grouped_conv_fwd_multiple_d_v4 +) +find_package(hip) + +foreach(TEST_SRC ${TEST_SRCS}) + get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) + rocm_add_test_executable(codegen_test_${BASE_NAME} ${TEST_SRC}) + target_link_libraries(codegen_test_${BASE_NAME} ck_rtc ck_host) + target_include_directories(codegen_test_${BASE_NAME} PUBLIC include) + if(BASE_NAME IN_LIST TESTS_REQUIRE_DEVICE_COMPILE) + target_link_libraries(codegen_test_${BASE_NAME} hip::device) target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/include) target_include_directories(codegen_test_${BASE_NAME} PUBLIC ${CK_ROOT}/library/include) - endforeach() -endif() + endif() +endforeach() diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp new file mode 100644 index 0000000000..13035df355 --- /dev/null +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -0,0 +1,82 @@ +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "common.hpp" +#include +#include +#include +#include + +using half = _Float16; + +const std::string gemm_compile_check = R"__ck__( +#include <${include}> + +extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* b1, ck::half_t* c) { + using G = ${template}; + constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${k}), ck::make_tuple(${m}, 1)), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(${n}, 1)), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${o}), ck::make_tuple(1, ${n})), + ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${o}), ck::make_tuple(${m}, 1))); + + static_assert(desc.IsValid(), "Invalid ck gemm."); + + if constexpr(desc.IsValid()) + { + ${template}::Run(desc, + 1.0, + a, + b, + b1, + c); + } +} + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + ck::host::device_batched_gemm_softmax_gemm::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + prob.O = 1024; + prob.TransB = true; + check_all check; + auto a = to_gpu(generate_buffer(1024 * 1024, 0)); + auto b = to_gpu(generate_buffer(1024 * 1024, 1)); + auto b1 = to_gpu(generate_buffer(1024 * 1024, 2)); + auto c = to_gpu(generate_buffer(1024 * 1024, 3)); + + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Num solutions: " << solutions.size() << std::endl; + for(auto i = 0; i < solutions.size(); ++i) + { + std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; + auto&& solution = solutions[i]; + auto src = ck::host::InterpolateString(gemm_compile_check, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + options.kernel_name = "f"; + auto k = rtc::compile_kernel(srcs, options); + auto block_size = solution.GetTemplateParameter("BlockSize"); + auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * + ck::host::integer_divide_ceil(prob.N, n_per_block); + k.launch(nullptr, grid_size * block_size, block_size)( + a.data(), b.data(), b1.data(), c.data()); + + // NOTE: Solutions where MaskOutUpperTriangle is True don't produce consistent results + CHECK(report(solution, check(rtc::from_gpu(c)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index bd7ef463fb..adc8e1ff02 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -1,136 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_gemm_multiple_d/operation.hpp" #include "ck/host/headers.hpp" #include "ck/host/stringutils.hpp" #include "ck/host/utils.hpp" -#include -#include -#include -#include -#include +#include "common.hpp" #include #include +#include +#include +#include #include +#include +#include using half = _Float16; -// using half = __fp16; - -std::vector get_headers_for_test() -{ - std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, p.second}; - }); - return result; -} - -template -rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) -{ - rtc::buffer result(n); - std::mt19937 gen(seed); - std::uniform_real_distribution dis(-1.0); - std::generate(result.begin(), result.end(), [&] { return dis(gen); }); - return result; -} - -template -bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) -{ - return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { - return fabs(x - y) < atol + rtol * fabs(y); - }); -} - -std::string classify(double x) -{ - switch(std::fpclassify(x)) - { - case FP_INFINITE: return "inf"; - case FP_NAN: return "nan"; - case FP_NORMAL: return "normal"; - case FP_SUBNORMAL: return "subnormal"; - case FP_ZERO: return "zero"; - default: return "unknown"; - } -} - -template -void print_classification(const Buffer& x) -{ - std::unordered_set result; - for(const auto& i : x) - result.insert(classify(i)); - for(const auto& c : result) - std::cout << c << ", "; - std::cout << std::endl; -} - -template -void print_statistics(const Buffer& x) -{ - std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; - std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; - double num_elements = x.size(); - auto mean = - std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; - auto stddev = std::sqrt( - std::accumulate(x.begin(), - x.end(), - double{0.0}, - [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / - num_elements); - std::cout << "Mean: " << mean << ", "; - std::cout << "StdDev: " << stddev << "\n"; -} - -template -void print_preview(const Buffer& x) -{ - if(x.size() <= 10) - { - std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); - } - else - { - std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); - std::cout << "..., "; - std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); - } - std::cout << std::endl; -} - -template -struct check_all -{ - rtc::buffer data{}; - bool operator()(const rtc::buffer& x) - { - if(data.empty()) - { - data = x; - return true; - } - if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); })) - return false; - return allclose(data, x); - } -}; - -template -auto report(const Solution& solution, bool pass) -{ - return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); -} const std::string gemm_compile_check = R"__ck__( #include <${include}> extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { using G = ${template}; - constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), + constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})), ck::make_tuple(), ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n}))); @@ -160,18 +53,19 @@ TEST_CASE(test_problem_kernel) auto b = to_gpu(generate_buffer(1024 * 1024, 1)); auto c = to_gpu(generate_buffer(1024 * 1024, 2)); - std::string epilogue = ""; - std::string prologue = ""; - - for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue)) + auto solutions = prob.GetSolutions("gfx90a"); + std::cout << "Num solutions: " << solutions.size() << std::endl; + for(auto i = 0; i < solutions.size(); ++i) { - auto src = ck::host::InterpolateString(gemm_compile_check, + std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; + auto&& solution = solutions[i]; + auto src = ck::host::InterpolateString(gemm_compile_check, {{"include", prob.GetIncludeHeader()}, {"template", solution.ToTemplateString()}, {"m", std::to_string(prob.M)}, {"n", std::to_string(prob.N)}, {"k", std::to_string(prob.K)}}); - auto srcs = get_headers_for_test(); + auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; options.kernel_name = "f"; diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp index 50290fa25a..9902caab04 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v1.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp index b558d97c78..205283e7aa 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v2.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp index e2972a93d2..2b83af2432 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v3.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp index b728096c51..fbe27e9c8b 100644 --- a/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp +++ b/codegen/test/grouped_conv_fwd_multiple_d_v4.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_op.hpp" #include "ck/host/device_grouped_conv_fwd_multiple_d/conv_fwd_problem.hpp" #include "ck/host/headers.hpp" diff --git a/codegen/test/common.hpp b/codegen/test/include/common.hpp similarity index 73% rename from codegen/test/common.hpp rename to codegen/test/include/common.hpp index 99d4c64973..b3be592e74 100644 --- a/codegen/test/common.hpp +++ b/codegen/test/include/common.hpp @@ -1,25 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + +#include "ck/host/headers.hpp" +#include +#include +#include #include #include #include #include #include -#include -#include -#include -#include +#include -std::vector get_headers_for_test() +inline std::vector create_headers_for_test() { + auto ck_headers = ck::host::GetHeaders(); std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, p.second}; - }); + std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) { + std::string content; + content.reserve(p.second.size() + 1); + content.push_back(' '); // We need a whitespace before the content for hipRTC to work + content.append(p.second.data(), p.second.size()); + return rtc::src_file{p.first, std::move(content)}; + }); return result; } +inline const std::vector& get_headers_for_test() +{ + static const std::vector headers = create_headers_for_test(); + return headers; +} + template std::size_t GetSize(V mLens, V mStrides) { @@ -34,18 +48,24 @@ std::size_t GetSize(V mLens, V mStrides) return space; } -template -rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +template +rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) { - std::size_t space = GetSize(mLens, mStrides); - rtc::buffer result(space); + rtc::buffer result(n); std::mt19937 gen(seed); std::uniform_real_distribution dis(-1.0); std::generate(result.begin(), result.end(), [&] { return dis(gen); }); - // std::fill(result.begin(), result.end(), 1); return result; } +template +std::enable_if_t, rtc::buffer> +generate_buffer(V mLens, V mStrides, std::size_t seed = 0) +{ + std::size_t space = GetSize(mLens, mStrides); + return generate_buffer(space, seed); +} + template bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) { @@ -54,7 +74,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) }); } -std::string classify(double x) +inline std::string classify(double x) { switch(std::fpclassify(x)) { diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 39497f1a21..b8a60cd633 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,4 +1,12 @@ +find_package(hip) file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) target_link_libraries(ck_rtc PUBLIC hip::host) +target_link_libraries(ck_rtc PUBLIC -lstdc++fs) + +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) + message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") +endif() diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index 5a4a4b0dd6..207f10a8e8 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -1,16 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL #include -#include +#include #include namespace rtc { struct src_file { - std::filesystem::path path; - std::string_view content; + src_file(std::filesystem::path p, std::string c) : path{std::move(p)}, content{std::move(c)} {} + fs::path path; + std::string content; }; struct compile_options @@ -19,7 +23,7 @@ struct compile_options std::string kernel_name = "main"; }; -kernel compile_kernel(const std::vector& src, +kernel compile_kernel(const std::vector& srcs, compile_options options = compile_options{}); } // namespace rtc diff --git a/codegen/test/rtc/include/rtc/filesystem.hpp b/codegen/test/rtc/include/rtc/filesystem.hpp new file mode 100644 index 0000000000..3b94b84b9f --- /dev/null +++ b/codegen/test/rtc/include/rtc/filesystem.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef GUARD_TEST_HOST_RTC_FILESYSTEM_HPP +#define GUARD_TEST_HOST_RTC_FILESYSTEM_HPP + +#include +#include + +// clang-format off +#if defined(CPPCHECK) + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 1 +#elif defined(_WIN32) + #if _MSC_VER >= 1920 + #define RTC_HAS_FILESYSTEM 1 + #define RTC_HAS_FILESYSTEM_TS 0 + #elif _MSC_VER >= 1900 + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#elif defined(__has_include) + #if __has_include() && __cplusplus >= 201703L + #define RTC_HAS_FILESYSTEM 1 + #else + #define RTC_HAS_FILESYSTEM 0 + #endif + #if __has_include() && __cplusplus >= 201103L + #define RTC_HAS_FILESYSTEM_TS 1 + #else + #define RTC_HAS_FILESYSTEM_TS 0 + #endif +#else + #define RTC_HAS_FILESYSTEM 0 + #define RTC_HAS_FILESYSTEM_TS 0 +#endif +// clang-format on + +#if RTC_HAS_FILESYSTEM +#include +#elif RTC_HAS_FILESYSTEM_TS +#include +#else +#error "No filesystem include available" +#endif + +namespace rtc { + +#if RTC_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif RTC_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace rtc + +#endif // GUARD_RTC_FILESYSTEM_HPP_ diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp index 6b523382dc..3163bb08ed 100644 --- a/codegen/test/rtc/include/rtc/hip.hpp +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -1,9 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP #include #include +#include #include +#include namespace rtc { diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9f38e90416..b1ee729f77 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp index 92edf12628..52b94d4b70 100644 --- a/codegen/test/rtc/include/rtc/manage_ptr.hpp +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index f0fd1f72bb..2f3b26cc43 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -1,14 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR #include -#include +#include namespace rtc { struct tmp_dir { - std::filesystem::path path; + fs::path path; tmp_dir(const std::string& prefix = ""); void execute(const std::string& cmd) const; diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index d84ebf4de9..262e6bae46 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,13 +1,43 @@ -#include "rtc/hip.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include #include +#ifdef HIPRTC_FOR_CODEGEN_TESTS +#include +#include +#endif #include -#include -#include -#include +#include #include +#include +#include +#include +#include +#include namespace rtc { +bool EndsWith(const std::string& value, const std::string& suffix) +{ + if(suffix.size() > value.size()) + return false; + else + return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin()); +} + +std::vector SplitString(const std::string& s, char delim) +{ + std::vector elems; + std::stringstream ss(s + delim); + std::string item; + while(std::getline(ss, item, delim)) + { + elems.push_back(item); + } + return elems; +} + template T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) { @@ -59,7 +89,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device // TODO: undo after extracting the codeobj // std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; } -kernel compile_kernel(const std::vector& srcs, compile_options options) +kernel clang_compile_kernel(const std::vector& srcs, compile_options options) { assert(not srcs.empty()); tmp_dir td{"compile"}; @@ -70,9 +100,9 @@ kernel compile_kernel(const std::vector& srcs, compile_options options for(const auto& src : srcs) { - std::filesystem::path full_path = td.path / src.path; - std::filesystem::path parent_path = full_path.parent_path(); - std::filesystem::create_directories(parent_path); + fs::path full_path = td.path / src.path; + fs::path parent_path = full_path.parent_path(); + fs::create_directories(parent_path); write_string(full_path.string(), src.content); if(src.path.extension().string() == ".cpp") { @@ -86,7 +116,7 @@ kernel compile_kernel(const std::vector& srcs, compile_options options td.execute(compiler() + options.flags); auto out_path = td.path / out; - if(not std::filesystem::exists(out_path)) + if(not fs::exists(out_path)) throw std::runtime_error("Output file missing: " + out); auto obj = read_buffer(out_path.string()); @@ -100,4 +130,173 @@ kernel compile_kernel(const std::vector& srcs, compile_options options return kernel{obj.data(), options.kernel_name}; } +#ifdef HIPRTC_FOR_CODEGEN_TESTS + +std::string hiprtc_error(hiprtcResult err, const std::string& msg) +{ + return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); +} + +void hiprtc_check_error(hiprtcResult err, const std::string& msg = "") +{ + if(err != HIPRTC_SUCCESS) + throw std::runtime_error(hiprtc_error(err, msg)); +} + +struct hiprtc_src_file +{ + hiprtc_src_file() = default; + hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {} + std::string path; + std::string content; +}; + +void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); } +using hiprtc_program_ptr = RTC_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy); + +template +hiprtc_program_ptr hiprtc_program_create(Ts... xs) +{ + hiprtcProgram prog = nullptr; + auto result = hiprtcCreateProgram(&prog, xs...); + hiprtc_program_ptr p{prog}; + hiprtc_check_error(result, "Create program failed."); + return p; +} + +struct hiprtc_program +{ + struct string_array + { + std::deque strings{}; + std::vector c_strs{}; + + string_array() {} + string_array(const string_array&) = delete; + + std::size_t size() const { return strings.size(); } + + const char** data() { return c_strs.data(); } + + void push_back(std::string s) + { + strings.push_back(std::move(s)); + c_strs.push_back(strings.back().c_str()); + } + }; + + hiprtc_program_ptr prog = nullptr; + string_array headers{}; + string_array include_names{}; + std::string cpp_src = ""; + std::string cpp_name = ""; + + hiprtc_program(const std::string& src, const std::string& name = "main.cpp") + : cpp_src(src), cpp_name(name) + { + create_program(); + } + + hiprtc_program(std::vector srcs) + { + for(auto&& src : srcs) + { + if(EndsWith(src.path, ".cpp")) + { + cpp_src = std::move(src.content); + cpp_name = std::move(src.path); + } + else + { + headers.push_back(std::move(src.content)); + include_names.push_back(std::move(src.path)); + } + } + create_program(); + } + + void create_program() + { + assert(not cpp_src.empty()); + assert(not cpp_name.empty()); + assert(headers.size() == include_names.size()); + prog = hiprtc_program_create(cpp_src.c_str(), + cpp_name.c_str(), + headers.size(), + headers.data(), + include_names.data()); + } + + void compile(const std::vector& options, bool quiet = false) const + { + std::vector c_options; + std::transform(options.begin(), + options.end(), + std::back_inserter(c_options), + [](const std::string& s) { return s.c_str(); }); + auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data()); + auto prog_log = log(); + if(not prog_log.empty() and not quiet) + { + std::cerr << prog_log << std::endl; + } + if(result != HIPRTC_SUCCESS) + throw std::runtime_error("Compilation failed."); + } + + std::string log() const + { + std::size_t n = 0; + hiprtc_check_error(hiprtcGetProgramLogSize(prog.get(), &n)); + if(n == 0) + return {}; + std::string buffer(n, '\0'); + hiprtc_check_error(hiprtcGetProgramLog(prog.get(), buffer.data())); + assert(buffer.back() != 0); + return buffer; + } + + std::vector get_code_obj() const + { + std::size_t n = 0; + hiprtc_check_error(hiprtcGetCodeSize(prog.get(), &n)); + std::vector buffer(n); + hiprtc_check_error(hiprtcGetCode(prog.get(), buffer.data())); + return buffer; + } +}; + +std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, + const compile_options& options) +{ + hiprtc_program prog(srcs); + auto flags = SplitString(options.flags, ' '); + prog.compile(flags); + return {prog.get_code_obj()}; +} + +static kernel hiprtc_compile_kernel(const std::vector& srcs, compile_options options) +{ + options.flags += " -I. -O3"; + options.flags += " -std=c++17"; + options.flags += " -DCK_CODE_GEN_RTC"; + options.flags += " --offload-arch=" + get_device_name(); + auto cos = compile_hip_src_with_hiprtc(srcs, options); + if(cos.size() != 1) + std::runtime_error("No code object"); + auto& obj = cos.front(); + return kernel{obj.data(), options.kernel_name}; +} + +#endif + +kernel compile_kernel(const std::vector& srcs, compile_options options) +{ +#ifdef HIPRTC_FOR_CODEGEN_TESTS + return hiprtc_compile_kernel(srcs, options); +#else + return clang_compile_kernel(srcs, options); +#endif +} + } // namespace rtc diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp index 747f83e3ba..6f16e36720 100644 --- a/codegen/test/rtc/src/hip.cpp +++ b/codegen/test/rtc/src/hip.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 9fe38e84ad..982e95de17 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -1,6 +1,10 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include +#include #include // extern declare the function since hip/hip_ext.h header is broken diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp index 1cc8f75b29..b36b17cce1 100644 --- a/codegen/test/rtc/src/tmp_dir.cpp +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #include #include #include @@ -31,10 +34,10 @@ std::string unique_string(const std::string& prefix) } tmp_dir::tmp_dir(const std::string& prefix) - : path(std::filesystem::temp_directory_path() / + : path(fs::temp_directory_path() / unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix)) { - std::filesystem::create_directories(this->path); + fs::create_directories(this->path); } void tmp_dir::execute(const std::string& cmd) const @@ -43,6 +46,6 @@ void tmp_dir::execute(const std::string& cmd) const std::system(s.c_str()); } -tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); } +tmp_dir::~tmp_dir() { fs::remove_all(this->path); } } // namespace rtc diff --git a/docs/reference/Supported_Primitives_Guide.rst b/docs/conceptual/Composable-Kernel-math.rst similarity index 85% rename from docs/reference/Supported_Primitives_Guide.rst rename to docs/conceptual/Composable-Kernel-math.rst index e24acf5656..1c21fd8a11 100644 --- a/docs/reference/Supported_Primitives_Guide.rst +++ b/docs/conceptual/Composable-Kernel-math.rst @@ -1,18 +1,15 @@ .. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation + :description: Composable Kernel mathematical basis + :keywords: composable kernel, CK, ROCm, API, mathematics, algorithm .. _supported-primitives: ******************************************************************** -Supported Primitives Guide +Composable Kernel mathematical basis ******************************************************************** -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. +This is an introduction to the math which underpins the algorithms implemented in Composable Kernel. ------------- -Softmax ------------- For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` you can decompose the softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, diff --git a/docs/conceptual/Composable-Kernel-structure.rst b/docs/conceptual/Composable-Kernel-structure.rst new file mode 100644 index 0000000000..43c3603b95 --- /dev/null +++ b/docs/conceptual/Composable-Kernel-structure.rst @@ -0,0 +1,29 @@ +.. meta:: + :description: Composable Kernel structure + :keywords: composable kernel, CK, ROCm, API, structure + +.. _what-is-ck: + +******************************************************************** +Composable Kernel structure +******************************************************************** + +The Composable Kernel library uses a tile-based programming model and tensor coordinate transformation to achieve performance portability and code maintainability. Tensor coordinate transformation is a complexity reduction technique for complex machine learning operators. + + +.. image:: ../data/ck_component.png + :alt: CK Components + + +The Composable Kernel library consists of four layers: + +* a templated tile operator layer +* a templated kernel and invoker layer +* an instantiated kernel and invoker layer +* a client API layer. + +A wrapper component is included to simplify tensor transform operations. + +.. image:: ../data/ck_layer.png + :alt: CK Layers + \ No newline at end of file diff --git a/docs/conceptual/what-is-ck.rst b/docs/conceptual/what-is-ck.rst deleted file mode 100644 index 36785fc6ca..0000000000 --- a/docs/conceptual/what-is-ck.rst +++ /dev/null @@ -1,41 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _what-is-ck: - -******************************************************************** -What is the Composable Kernel library -******************************************************************** - - -Methodology -=========== - -The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. - -CK utilizes two concepts to achieve performance portability and code maintainability: - -* A tile-based programming model -* Algorithm complexity reduction for complex ML operators using an innovative technique called - "Tensor Coordinate Transformation". - -.. image:: ../data/ck_component.png - :alt: CK Components - - -Code Structure -============== - -The CK library is structured into 4 layers: - -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer - -It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. - -.. image:: ../data/ck_layer.png - :alt: CK Layers - \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e8617a09ef..fe8a1c1d79 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,6 +28,7 @@ external_toc_path = "./sphinx/_toc.yml" docs_core = ROCmDocs(left_nav_title) docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") +docs_core.enable_api_reference() docs_core.setup() external_projects_current_project = "composable_kernel" diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index fac9e138e1..4c8019f8d3 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -1,4 +1,4 @@ -# Doxyfile 1.8.10 +# Doxyfile 1.9.7 # This file describes the settings to be used by the documentation system # doxygen (www.doxygen.org) for a project. @@ -12,16 +12,26 @@ # For lists, items can also be appended using: # TAG += value [value, ...] # Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options #--------------------------------------------------------------------------- -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. # The default value is: UTF-8. DOXYFILE_ENCODING = UTF-8 @@ -32,26 +42,26 @@ DOXYFILE_ENCODING = UTF-8 # title of most generated pages and in a few other places. # The default value is: My Project. -PROJECT_NAME = "ck" +PROJECT_NAME = "Composable Kernel" # The PROJECT_NUMBER tag can be used to enter a project or revision number. This # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v3.0.1.0 +PROJECT_NUMBER = # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a # quick idea about the purpose of the project. Keep the description short. -PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HiP" +PROJECT_BRIEF = "Prototype interfaces compatible with ROCm platform and HiP" # With the PROJECT_LOGO tag one can specify a logo or an icon that is included # in the documentation. The maximum height of the logo should not exceed 55 # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # the logo to the output directory. -PROJECT_LOGO = +PROJECT_LOGO = # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # into which the generated documentation will be written. If a relative path is @@ -60,16 +70,28 @@ PROJECT_LOGO = OUTPUT_DIRECTORY = . -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this # option can be useful when feeding doxygen a huge amount of source files, where # putting all generated files in the same directory would otherwise causes -# performance problems for the file system. +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. # The default value is: NO. CREATE_SUBDIRS = NO +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# number of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + # If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII # characters to appear in the names of generated files. If set to NO, non-ASCII # characters will be escaped, for example _xE3_x81_x84 will be used for Unicode @@ -81,14 +103,14 @@ ALLOW_UNICODE_NAMES = NO # The OUTPUT_LANGUAGE tag is used to specify the language in which all # documentation generated by doxygen is written. Doxygen will use this # information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, -# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), -# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, -# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), -# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, -# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, -# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, -# Ukrainian and Vietnamese. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. # The default value is: English. OUTPUT_LANGUAGE = English @@ -162,7 +184,8 @@ FULL_PATH_NAMES = YES # will be relative from the directory where doxygen is started. # This tag requires that the tag FULL_PATH_NAMES is set to YES. -STRIP_FROM_PATH = +#STRIP_FROM_PATH = +STRIP_FROM_PATH = /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/latest/ # The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the # path mentioned in the documentation of a class, which tells the reader which @@ -171,7 +194,8 @@ STRIP_FROM_PATH = # specify the list of include paths that are normally passed to the compiler # using the -I flag. -STRIP_FROM_INC_PATH = +STRIP_FROM_INC_PATH = + # If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but # less readable) file names. This can be useful is your file systems doesn't @@ -189,6 +213,16 @@ SHORT_NAMES = NO JAVADOC_AUTOBRIEF = NO +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + +JAVADOC_BANNER = NO + # If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first # line (until the first dot) of a Qt-style comment as the brief description. If # set to NO, the Qt-style will behave just like regular Qt-style comments (thus @@ -209,6 +243,14 @@ QT_AUTOBRIEF = NO MULTILINE_CPP_IS_BRIEF = NO +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + +PYTHON_DOCSTRING = YES + # If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the # documentation from any documented member that it re-implements. # The default value is: YES. @@ -232,20 +274,19 @@ TAB_SIZE = 4 # the documentation. An alias has the form: # name=value # For example adding -# "sideeffect=@par Side Effects:\n" +# "sideeffect=@par Side Effects:^^" # will allow you to put the command \sideeffect (or @sideeffect) in the # documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) ALIASES = -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - # Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources # only. Doxygen will then generate output that is more tailored for C. For # instance, some of the names that are used will be different. The list of all @@ -274,28 +315,40 @@ OPTIMIZE_FOR_FORTRAN = NO OPTIMIZE_OUTPUT_VHDL = NO +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_SLICE = NO + # Doxygen selects the parser to use depending on the extension of the files it # parses. With this tag you can assign which parser to use for a given # extension. Doxygen has a built-in mapping, but you can override or extend it # using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, Javascript, -# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: -# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: -# Fortran. In the later case the parser tries to guess whether the code is fixed -# or free formatted code, this is the default for Fortran type files), VHDL. For -# instance to make doxygen treat .inc files as Fortran files (default is PHP), -# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. # # Note: For files without extension you can use no_extension as a placeholder. # # Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. EXTENSION_MAPPING = # If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments # according to the Markdown format, which allows for more readable -# documentation. See http://daringfireball.net/projects/markdown/ for details. +# documentation. See https://daringfireball.net/projects/markdown/ for details. # The output of markdown processing is further processed by doxygen, so you can # mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in # case of backward compatibilities issues. @@ -303,6 +356,26 @@ EXTENSION_MAPPING = MARKDOWN_SUPPORT = YES +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to +# generate identifiers for the Markdown headings. Note: Every identifier is +# unique. +# Possible values are: DOXYGEN Use a fixed 'autotoc_md' string followed by a +# sequence number starting at 0. and GITHUB Use the lower case version of title +# with any whitespace replaced by '-' and punctations characters removed.. +# The default value is: DOXYGEN. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +MARKDOWN_ID_STYLE = DOXYGEN + # When enabled doxygen tries to link words that correspond to documented # classes, or namespaces to their corresponding documentation. Such a link can # be prevented in individual cases by putting a % sign in front of the word or @@ -328,7 +401,7 @@ BUILTIN_STL_SUPPORT = YES CPP_CLI_SUPPORT = NO # Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen # will parse them like normal C++ but will assume all classes use public instead # of private inheritance when no explicit protection keyword is present. # The default value is: NO. @@ -414,6 +487,27 @@ TYPEDEF_HIDES_STRUCT = YES LOOKUP_CACHE_SIZE = 0 +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +# If the TIMESTAMP tag is set different from NO then each generated page will +# contain the date or date and time when the page was generated. Setting this to +# NO can help when comparing the output of multiple runs. +# Possible values are: YES, NO, DATETIME and DATE. +# The default value is: NO. + +TIMESTAMP = YES + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- @@ -434,6 +528,12 @@ EXTRACT_ALL = YES EXTRACT_PRIVATE = NO +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + # If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal # scope will be included in the documentation. # The default value is: NO. @@ -471,6 +571,13 @@ EXTRACT_LOCAL_METHODS = NO EXTRACT_ANON_NSPACES = NO +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + # If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all # undocumented members inside documented classes or files. If set to NO these # members will be included in the various overviews, but no documentation @@ -482,14 +589,15 @@ HIDE_UNDOC_MEMBERS = NO # If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all # undocumented classes that are normally visible in the class hierarchy. If set # to NO, these classes will be included in the various overviews. This option -# has no effect if EXTRACT_ALL is enabled. +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. # The default value is: NO. HIDE_UNDOC_CLASSES = NO # If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# (class|struct|union) declarations. If set to NO, these declarations will be -# included in the documentation. +# declarations. If set to NO, these declarations will be included in the +# documentation. # The default value is: NO. HIDE_FRIEND_COMPOUNDS = NO @@ -508,12 +616,20 @@ HIDE_IN_BODY_DOCS = NO INTERNAL_DOCS = NO -# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file -# names in lower-case letters. If set to YES, upper-case letters are also -# allowed. This is useful if you have classes or files whose names only differ -# in case and if your file system supports case sensitive file names. Windows -# and Mac users are advised to set this option to NO. -# The default value is: system dependent. +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. CASE_SENSE_NAMES = NO @@ -531,6 +647,12 @@ HIDE_SCOPE_NAMES = NO HIDE_COMPOUND_REFERENCE= NO +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + # If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of # the files that are included by a file in the documentation of that file. # The default value is: YES. @@ -688,7 +810,8 @@ FILE_VERSION_FILTER = # output files in an output format independent way. To create the layout file # that represents doxygen's defaults, run doxygen with the -l option. You can # optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. # # Note that if you run doxygen from a directory containing a file called # DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE @@ -699,7 +822,7 @@ LAYOUT_FILE = # The CITE_BIB_FILES tag can be used to specify one or more bib files containing # the reference definitions. This must be a list of .bib files. The .bib # extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. # For LaTeX the style of the bibliography can be controlled using # LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the # search path. See also \cite for info how to create references. @@ -734,34 +857,81 @@ WARNINGS = YES WARN_IF_UNDOCUMENTED = YES # If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as not documenting some parameters -# in a documented function, or documenting parameters that don't exist or using -# markup commands wrongly. +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. # The default value is: YES. WARN_IF_DOC_ERROR = YES +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + # This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that # are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong or incomplete -# parameter documentation, but not about the absence of documentation. +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC # The default value is: NO. WARN_NO_PARAMDOC = NO +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about +# undocumented enumeration values. If set to NO, doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then doxygen behaves +# like FAIL_ON_WARNINGS but in case no WARN_LOGFILE is defined doxygen will not +# write the warning messages in between other messages but write them at the end +# of a run, in case a WARN_LOGFILE is defined the warning messages will be +# besides being in the defined file also be shown at the end of a run, unless +# the WARN_LOGFILE is defined as - i.e. standard output (stdout) in that case +# the behavior will remain as with the setting FAIL_ON_WARNINGS. +# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT. +# The default value is: NO. + +WARN_AS_ERROR = NO + # The WARN_FORMAT tag determines the format of the warning messages that doxygen # can produce. The string should contain the $file, $line, and $text tags, which # will be replaced by the file and line number from which the warning originated # and the warning text. Optionally the format may contain $version, which will # be replaced by the version of the file (if it could be obtained via # FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT # The default value is: $file:$line: $text. WARN_FORMAT = "$file:$line: $text" +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + # The WARN_LOGFILE tag can be used to specify a file to which warning and error # messages should be written. If left blank the output is written to standard -# error (stderr). +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). WARN_LOGFILE = @@ -775,22 +945,31 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../../include/ck/tensor_operation/gpu/grid \ - ../../include/ck/tensor_operation/gpu/block \ - ../../include/ck/tensor_operation/gpu/thread \ +INPUT = ../../include \ + ../../include/ck/ \ ../../library/include/ck/library/utility \ - ../../include/ck/wrapper - + ../../include/ck_tile # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: http://www.gnu.org/software/libiconv) for the list of -# possible encodings. +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# See also: INPUT_FILE_ENCODING # The default value is: UTF-8. INPUT_ENCODING = UTF-8 +# This tag can be used to specify the character encoding of the source files +# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify +# character encoding on a per file pattern basis. Doxygen will compare the file +# name with each pattern and apply the encoding instead of the default +# INPUT_ENCODING) if there is a match. The character encodings are a list of the +# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding +# "INPUT_ENCODING" for further information on supported encodings. + +INPUT_FILE_ENCODING = + # If the value of the INPUT tag contains directories, you can use the # FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and # *.h) to filter out the source-files in the directories. @@ -799,11 +978,15 @@ INPUT_ENCODING = UTF-8 # need to set EXTENSION_MAPPING for the extension otherwise the files are not # read by doxygen. # +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# # If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, # *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, -# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, -# *.vhdl, *.ucf, *.qsf, *.as and *.js. +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. FILE_PATTERNS = *.c \ *.cc \ @@ -824,6 +1007,7 @@ FILE_PATTERNS = *.c \ *.hxx \ *.hpp \ *.h++ \ + *.l \ *.cs \ *.d \ *.php \ @@ -837,13 +1021,19 @@ FILE_PATTERNS = *.c \ *.mm \ *.dox \ *.py \ - *.tcl \ + *.pyw \ + *.f90 \ + *.f95 \ + *.f03 \ + *.f08 \ + *.f18 \ + *.f \ + *.for \ *.vhd \ *.vhdl \ *.ucf \ *.qsf \ - *.as \ - *.js + *.ice # The RECURSIVE tag can be used to specify whether or not subdirectories should # be searched for input files as well. @@ -880,10 +1070,7 @@ EXCLUDE_PATTERNS = # (namespaces, classes, functions, etc.) that should be excluded from the # output. The symbol name can be a fully qualified name, a word, or if the # wildcard * is used, a substring. Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* +# ANamespace::AClass, ANamespace::*Test EXCLUDE_SYMBOLS = @@ -927,6 +1114,15 @@ IMAGE_PATH = # Note that the filter must not add or remove lines; it is applied before the # code is scanned, but not when the output code is generated. If lines are added # or removed, the anchors will not be placed correctly. +# +# Note that doxygen will use the data processed and written to standard output +# for further processing, therefore nothing else, like debug statements or used +# commands (so in case of a Windows batch file always use @echo OFF), should be +# written to standard output. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. INPUT_FILTER = @@ -936,6 +1132,10 @@ INPUT_FILTER = # (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how # filters are used. If the FILTER_PATTERNS tag is empty or if none of the # patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. FILTER_PATTERNS = @@ -959,7 +1159,17 @@ FILTER_SOURCE_PATTERNS = # (index.html). This can be useful if you have a project on for instance GitHub # and want to reuse the introduction page also for the doxygen output. -USE_MDFILE_AS_MAINPAGE = ../README.md + +USE_MDFILE_AS_MAINPAGE = + +# The Fortran standard specifies that for fixed formatted Fortran code all +# characters from position 72 are to be considered as comment. A common +# extension is to allow longer lines before the automatic comment starts. The +# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can +# be processed before the automatic comment starts. +# Minimum value: 7, maximum value: 10000, default value: 72. + +FORTRAN_COMMENT_AFTER = 72 #--------------------------------------------------------------------------- # Configuration options related to source browsing @@ -988,7 +1198,7 @@ INLINE_SOURCES = NO STRIP_CODE_COMMENTS = YES # If the REFERENCED_BY_RELATION tag is set to YES then for each documented -# function all documented functions referencing it will be listed. +# entity all documented functions referencing it will be listed. # The default value is: NO. REFERENCED_BY_RELATION = NO @@ -1020,12 +1230,12 @@ SOURCE_TOOLTIPS = YES # If the USE_HTAGS tag is set to YES then the references to source code will # point to the HTML generated by the htags(1) tool instead of doxygen built-in # source browser. The htags tool is part of GNU's global source tagging system -# (see http://www.gnu.org/software/global/global.html). You will need version +# (see https://www.gnu.org/software/global/global.html). You will need version # 4.8.6 or higher. # # To use it do the following: # - Install the latest version of global -# - Enable SOURCE_BROWSER and USE_HTAGS in the config file +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file # - Make sure the INPUT points to the root of the source tree # - Run doxygen as normal # @@ -1047,25 +1257,6 @@ USE_HTAGS = NO VERBATIM_HEADERS = YES -# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the -# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the -# cost of reduced performance. This can be particularly helpful with template -# rich C++ code for which doxygen's built-in parser lacks the necessary type -# information. -# Note: The availability of this option depends on whether or not doxygen was -# compiled with the --with-libclang option. -# The default value is: NO. - -CLANG_ASSISTED_PARSING = NO - -# If clang assisted parsing is enabled you can provide the compiler with command -# line options that you would normally use when invoking the compiler. Note that -# the include paths will already be set by doxygen for the files and directories -# specified with INPUT and INCLUDE_PATH. -# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. - -CLANG_OPTIONS = - #--------------------------------------------------------------------------- # Configuration options related to the alphabetical class index #--------------------------------------------------------------------------- @@ -1077,17 +1268,11 @@ CLANG_OPTIONS = ALPHABETICAL_INDEX = YES -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -COLS_IN_ALPHA_INDEX = 5 - -# In case all classes in a project start with a common prefix, all classes will -# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag -# can be used to specify a prefix (or a list of prefixes) that should be ignored -# while generating the index headers. +# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) +# that should be ignored while generating the index headers. The IGNORE_PREFIX +# tag works for classes, function and member names. The entity will be placed in +# the alphabetical list under the first letter of the entity name that remains +# after removing the prefix. # This tag requires that the tag ALPHABETICAL_INDEX is set to YES. IGNORE_PREFIX = @@ -1134,7 +1319,7 @@ HTML_FILE_EXTENSION = .html # of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_HEADER = +HTML_HEADER = ../_doxygen/header.html # The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each # generated HTML page. If the tag is left blank doxygen will generate a standard @@ -1144,7 +1329,7 @@ HTML_HEADER = # that doxygen normally uses. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_FOOTER = +HTML_FOOTER = ../_doxygen/footer.html # The HTML_STYLESHEET tag can be used to specify a user-defined cascading style # sheet that is used by each HTML page. It can be used to fine-tune the look of @@ -1156,7 +1341,7 @@ HTML_FOOTER = # obsolete. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_STYLESHEET = +HTML_STYLESHEET = ../_doxygen/stylesheet.css # The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined # cascading style sheets that are included after the standard style sheets @@ -1166,10 +1351,15 @@ HTML_STYLESHEET = # Doxygen will copy the style sheet files to the output directory. # Note: The order of the extra style sheet files is of importance (e.g. the last # style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_STYLESHEET = +HTML_EXTRA_STYLESHEET = ../_doxygen/extra_stylesheet.css # The HTML_EXTRA_FILES tag can be used to specify one or more extra images or # other source files which should be copied to the HTML output directory. Note @@ -1179,21 +1369,34 @@ HTML_EXTRA_STYLESHEET = # files will be copied as-is; there are no commands or markers available. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_FILES = +HTML_EXTRA_FILES = ../_doxygen/extra_stylesheet.css + +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. +# Possible values are: LIGHT always generate light mode output, DARK always +# generate dark mode output, AUTO_LIGHT automatically set the mode according to +# the user preference, use light mode if no preference is set (the default), +# AUTO_DARK automatically set the mode according to the user preference, use +# dark mode if no preference is set and TOGGLE allow to user to switch between +# light and dark mode via a button. +# The default value is: AUTO_LIGHT. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE = LIGHT # The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen # will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a colorwheel, see -# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value # 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 # purple, and 360 is red again. # Minimum value: 0, maximum value: 359, default value: 220. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_COLORSTYLE_HUE = 220 +HTML_COLORSTYLE_HUE = 240 # The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use grayscales only. A +# in the HTML output. For a value of 0 the output will use gray-scales only. A # value of 255 will produce the most vivid colors. # Minimum value: 0, maximum value: 255, default value: 100. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1211,14 +1414,16 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_TIMESTAMP = NO +HTML_DYNAMIC_MENUS = YES # If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML # documentation will contain sections that can be hidden and shown after the @@ -1243,13 +1448,14 @@ HTML_INDEX_NUM_ENTRIES = 100 # If the GENERATE_DOCSET tag is set to YES, additional index files will be # generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: http://developer.apple.com/tools/xcode/), introduced with -# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a -# Makefile in the HTML output directory. Running make will produce the docset in -# that directory and running make install will install the docset in +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in # ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html -# for more information. +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1263,6 +1469,13 @@ GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + # This tag specifies a string that should uniquely identify the documentation # set bundle. This should be a reverse domain-name style string, e.g. # com.mycompany.MyDocSet. Doxygen will append .docset to the name. @@ -1288,8 +1501,12 @@ DOCSET_PUBLISHER_NAME = Publisher # If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three # additional HTML index files: index.hhp, index.hhc, and index.hhk. The # index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on -# Windows. +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). # # The HTML Help Workshop contains a compiler that can convert all HTML output # generated by doxygen into a single compiled HTML file (.chm). Compiled HTML @@ -1319,7 +1536,7 @@ CHM_FILE = HHC_LOCATION = # The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the master .chm file (NO). +# (YES) or that it should be included in the main .chm file (NO). # The default value is: NO. # This tag requires that the tag GENERATE_HTMLHELP is set to YES. @@ -1346,6 +1563,16 @@ BINARY_TOC = NO TOC_EXPAND = NO +# The SITEMAP_URL tag is used to specify the full URL of the place where the +# generated documentation will be placed on the server by the user during the +# deployment of the documentation. The generated sitemap is called sitemap.xml +# and placed on the directory specified by HTML_OUTPUT. In case no SITEMAP_URL +# is specified no sitemap is generated. For information about the sitemap +# protocol see https://www.sitemaps.org +# This tag requires that the tag GENERATE_HTML is set to YES. + +SITEMAP_URL = + # If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and # QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that # can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help @@ -1364,7 +1591,8 @@ QCH_FILE = # The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help # Project output. For more information please see Qt Help Project / Namespace -# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). # The default value is: org.doxygen.Project. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1372,8 +1600,8 @@ QHP_NAMESPACE = org.doxygen.Project # The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt # Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- -# folders). +# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). # The default value is: doc. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1381,30 +1609,30 @@ QHP_VIRTUAL_FOLDER = doc # If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom # filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_NAME = # The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the # custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_ATTRS = # The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this # project's filter section matches. Qt Help Project / Filter Attributes (see: -# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_SECT_FILTER_ATTRS = -# The QHG_LOCATION tag can be used to specify the location of Qt's -# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the -# generated .qhp file. +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. # This tag requires that the tag GENERATE_QHP is set to YES. QHG_LOCATION = @@ -1447,16 +1675,28 @@ DISABLE_INDEX = NO # to work a browser that supports JavaScript, DHTML, CSS and frames is required # (i.e. any modern browser). Windows users are probably better off using the # HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine-tune the look of the index. As an example, the default style -# sheet generated by doxygen has an example that shows how to put an image at -# the root of the tree instead of the PROJECT_NAME. Since the tree basically has -# the same information as the tab index, you could consider setting -# DISABLE_INDEX to YES when enabling this option. +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. GENERATE_TREEVIEW = NO +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + # The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that # doxygen will group on one line in the generated HTML documentation. # @@ -1481,6 +1721,24 @@ TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FORMULA_FORMAT = png + # Use this tag to change the font size of LaTeX formulas included as images in # the HTML documentation. When you change the font size after a successful # doxygen run you need to manually remove any form_*.png images from the HTML @@ -1490,19 +1748,14 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. -FORMULA_TRANSPARENT = YES +FORMULA_MACROFILE = # Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see -# http://www.mathjax.org) which uses client side Javascript for the rendering +# https://www.mathjax.org) which uses client side JavaScript for the rendering # instead of using pre-rendered bitmaps. Use this if you do not have LaTeX # installed or if you want to formulas look prettier in the HTML output. When # enabled you may also need to install MathJax separately and configure the path @@ -1512,11 +1765,29 @@ FORMULA_TRANSPARENT = YES USE_MATHJAX = YES +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + # When MathJax is enabled you can set the default output format to be used for -# the MathJax output. See the MathJax site (see: -# http://docs.mathjax.org/en/latest/output.html) for more details. +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). # Possible values are: HTML-CSS (which is slower, but has the best -# compatibility), NativeMML (i.e. MathML) and SVG. +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for NathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. # The default value is: HTML-CSS. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1529,22 +1800,29 @@ MATHJAX_FORMAT = HTML-CSS # MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax # Content Delivery Network so you can quickly see the result without installing # MathJax. However, it is strongly recommended to install a local copy of -# MathJax from http://www.mathjax.org before deployment. -# The default value is: http://cdn.mathjax.org/mathjax/latest. +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 # This tag requires that the tag USE_MATHJAX is set to YES. -MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest +MATHJAX_RELPATH = # The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax # extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): # MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams # This tag requires that the tag USE_MATHJAX is set to YES. MATHJAX_EXTENSIONS = # The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces # of code that will be used on startup of the MathJax code. See the MathJax site -# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an # example see the documentation. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1569,10 +1847,10 @@ MATHJAX_CODEFILE = # The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -SEARCHENGINE = YES +SEARCHENGINE = NO # When the SERVER_BASED_SEARCH tag is enabled the search engine will be -# implemented using a web server instead of a web client using Javascript. There +# implemented using a web server instead of a web client using JavaScript. There # are two flavors of web server based searching depending on the EXTERNAL_SEARCH # setting. When disabled, doxygen will generate a PHP script for searching and # an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing @@ -1591,7 +1869,8 @@ SERVER_BASED_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). +# Xapian (see: +# https://xapian.org/). # # See the section "External Indexing and Searching" for details. # The default value is: NO. @@ -1604,8 +1883,9 @@ EXTERNAL_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). See the section "External Indexing and -# Searching" for details. +# Xapian (see: +# https://xapian.org/). See the section "External Indexing and Searching" for +# details. # This tag requires that the tag SEARCHENGINE is set to YES. SEARCHENGINE_URL = @@ -1656,21 +1936,35 @@ LATEX_OUTPUT = latex # The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be # invoked. # -# Note that when enabling USE_PDFLATEX this option is only used for generating -# bitmaps for formulas in the HTML output, but not in the Makefile that is -# written to the output directory. -# The default file is: latex. +# Note that when not enabling USE_PDFLATEX the default is latex when enabling +# USE_PDFLATEX the default is pdflatex and when in the later case latex is +# chosen this is overwritten by pdflatex. For specific output languages the +# default can have been set differently, this depends on the implementation of +# the output language. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_CMD_NAME = latex # The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate # index for LaTeX. +# Note: This tag is used in the Makefile / make.bat. +# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file +# (.tex). # The default file is: makeindex. # This tag requires that the tag GENERATE_LATEX is set to YES. MAKEINDEX_CMD_NAME = makeindex +# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to +# generate index for LaTeX. In case there is no backslash (\) as first character +# it will be automatically added in the LaTeX code. +# Note: This tag is used in the generated output file (.tex). +# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat. +# The default value is: makeindex. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_MAKEINDEX_CMD = makeindex + # If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX # documents. This may be useful for small projects and may help to save some # trees in general. @@ -1700,29 +1994,31 @@ PAPER_TYPE = a4 EXTRA_PACKAGES = -# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the -# generated LaTeX document. The header should contain everything until the first -# chapter. If it is left blank doxygen will generate a standard header. See -# section "Doxygen usage" for information on how to let doxygen write the -# default header to a separate file. +# The LATEX_HEADER tag can be used to specify a user-defined LaTeX header for +# the generated LaTeX document. The header should contain everything until the +# first chapter. If it is left blank doxygen will generate a standard header. It +# is highly recommended to start with a default header using +# doxygen -w latex new_header.tex new_footer.tex new_stylesheet.sty +# and then modify the file new_header.tex. See also section "Doxygen usage" for +# information on how to generate the default header that doxygen normally uses. # -# Note: Only use a user-defined header if you know what you are doing! The -# following commands have a special meaning inside the header: $title, -# $datetime, $date, $doxygenversion, $projectname, $projectnumber, -# $projectbrief, $projectlogo. Doxygen will replace $title with the empty -# string, for the replacement values of the other commands the user is referred -# to HTML_HEADER. +# Note: Only use a user-defined header if you know what you are doing! +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. The following +# commands have a special meaning inside the header (and footer): For a +# description of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_HEADER = -# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the -# generated LaTeX document. The footer should contain everything after the last -# chapter. If it is left blank doxygen will generate a standard footer. See +# The LATEX_FOOTER tag can be used to specify a user-defined LaTeX footer for +# the generated LaTeX document. The footer should contain everything after the +# last chapter. If it is left blank doxygen will generate a standard footer. See # LATEX_HEADER for more information on how to generate a default footer and what -# special commands can be used inside the footer. -# -# Note: Only use a user-defined footer if you know what you are doing! +# special commands can be used inside the footer. See also section "Doxygen +# usage" for information on how to generate the default footer that doxygen +# normally uses. Note: Only use a user-defined footer if you know what you are +# doing! # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_FOOTER = @@ -1755,18 +2051,26 @@ LATEX_EXTRA_FILES = PDF_HYPERLINKS = YES -# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate -# the PDF file directly from the LaTeX files. Set this option to YES, to get a -# higher quality PDF documentation. +# If the USE_PDFLATEX tag is set to YES, doxygen will use the engine as +# specified with LATEX_CMD_NAME to generate the PDF file directly from the LaTeX +# files. Set this option to YES, to get a higher quality PDF documentation. +# +# See also section LATEX_CMD_NAME for selecting the engine. # The default value is: YES. # This tag requires that the tag GENERATE_LATEX is set to YES. USE_PDFLATEX = YES -# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode -# command to the generated LaTeX files. This will instruct LaTeX to keep running -# if errors occur, instead of asking the user for help. This option is also used -# when generating formulas in HTML. +# The LATEX_BATCHMODE tag ignals the behavior of LaTeX in case of an error. +# Possible values are: NO same as ERROR_STOP, YES same as BATCH, BATCH In batch +# mode nothing is printed on the terminal, errors are scrolled as if is +# hit at every error; missing files that TeX tries to input or request from +# keyboard input (\read on a not open input stream) cause the job to abort, +# NON_STOP In nonstop mode the diagnostic message will appear on the terminal, +# but there is no possibility of user interaction just like in batch mode, +# SCROLL In scroll mode, TeX will stop only for missing files to input or if +# keyboard input is necessary and ERROR_STOP In errorstop mode, TeX will stop at +# each error, asking for user intervention. # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. @@ -1779,24 +2083,22 @@ LATEX_BATCHMODE = NO LATEX_HIDE_INDICES = NO -# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source -# code with syntax highlighting in the LaTeX output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_SOURCE_CODE = NO - # The LATEX_BIB_STYLE tag can be used to specify the style to use for the # bibliography, e.g. plainnat, or ieeetr. See -# http://en.wikipedia.org/wiki/BibTeX and \cite for more info. +# https://en.wikipedia.org/wiki/BibTeX and \cite for more info. # The default value is: plain. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_BIB_STYLE = plain +# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) +# path from which the emoji images will be read. If a relative path is entered, +# it will be relative to the LATEX_OUTPUT directory. If left blank the +# LATEX_OUTPUT directory will be used. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_EMOJI_DIRECTORY = + #--------------------------------------------------------------------------- # Configuration options related to the RTF output #--------------------------------------------------------------------------- @@ -1836,9 +2138,9 @@ COMPACT_RTF = NO RTF_HYPERLINKS = NO -# Load stylesheet definitions from file. Syntax is similar to doxygen's config -# file, i.e. a series of assignments. You only have to provide replacements, -# missing definitions are set to their default value. +# Load stylesheet definitions from file. Syntax is similar to doxygen's +# configuration file, i.e. a series of assignments. You only have to provide +# replacements, missing definitions are set to their default value. # # See also section "Doxygen usage" for information on how to generate the # default style sheet that doxygen normally uses. @@ -1847,22 +2149,12 @@ RTF_HYPERLINKS = NO RTF_STYLESHEET_FILE = # Set optional variables used in the generation of an RTF document. Syntax is -# similar to doxygen's config file. A template extensions file can be generated -# using doxygen -e rtf extensionFile. +# similar to doxygen's configuration file. A template extensions file can be +# generated using doxygen -e rtf extensionFile. # This tag requires that the tag GENERATE_RTF is set to YES. RTF_EXTENSIONS_FILE = -# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code -# with syntax highlighting in the RTF output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_RTF is set to YES. - -RTF_SOURCE_CODE = NO - #--------------------------------------------------------------------------- # Configuration options related to the man page output #--------------------------------------------------------------------------- @@ -1934,6 +2226,13 @@ XML_OUTPUT = xml XML_PROGRAMLISTING = YES +# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include +# namespace members in file scope as well, matching the HTML output. +# The default value is: NO. +# This tag requires that the tag GENERATE_XML is set to YES. + +XML_NS_MEMB_FILE_SCOPE = NO + #--------------------------------------------------------------------------- # Configuration options related to the DOCBOOK output #--------------------------------------------------------------------------- @@ -1952,23 +2251,14 @@ GENERATE_DOCBOOK = NO DOCBOOK_OUTPUT = docbook -# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the -# program listings (including syntax highlighting and cross-referencing -# information) to the DOCBOOK output. Note that enabling this will significantly -# increase the size of the DOCBOOK output. -# The default value is: NO. -# This tag requires that the tag GENERATE_DOCBOOK is set to YES. - -DOCBOOK_PROGRAMLISTING = NO - #--------------------------------------------------------------------------- # Configuration options for the AutoGen Definitions output #--------------------------------------------------------------------------- # If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an -# AutoGen Definitions (see http://autogen.sf.net) file that captures the -# structure of the code including all documentation. Note that this feature is -# still experimental and incomplete at the moment. +# AutoGen Definitions (see https://autogen.sourceforge.net/) file that captures +# the structure of the code including all documentation. Note that this feature +# is still experimental and incomplete at the moment. # The default value is: NO. GENERATE_AUTOGEN_DEF = NO @@ -2047,7 +2337,8 @@ SEARCH_INCLUDES = NO # The INCLUDE_PATH tag can be used to specify one or more directories that # contain include files that are not input files but should be processed by the -# preprocessor. +# preprocessor. Note that the INCLUDE_PATH is not recursive, so the setting of +# RECURSIVE has no effect here. # This tag requires that the tag SEARCH_INCLUDES is set to YES. INCLUDE_PATH = @@ -2113,7 +2404,7 @@ TAGFILES = # tag file that is based on the input files it reads. See section "Linking to # external documentation" for more information about the usage of tag files. -GENERATE_TAGFILE = +GENERATE_TAGFILE = html/tagfile.xml # If the ALLEXTERNALS tag is set to YES, all external class will be listed in # the class index. If set to NO, only the inherited external classes will be @@ -2136,41 +2427,10 @@ EXTERNAL_GROUPS = YES EXTERNAL_PAGES = YES -# The PERL_PATH should be the absolute path and name of the perl script -# interpreter (i.e. the result of 'which perl'). -# The default file (with absolute path) is: /usr/bin/perl. - -PERL_PATH = /usr/bin/perl - #--------------------------------------------------------------------------- -# Configuration options related to the dot tool +# Configuration options related to diagram generator tools #--------------------------------------------------------------------------- -# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram -# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to -# NO turns the diagrams off. Note that this option also works with HAVE_DOT -# disabled, but it is recommended to install and use dot, since it yields more -# powerful graphs. -# The default value is: YES. - -CLASS_DIAGRAMS = NO - -# You can define message sequence charts within doxygen comments using the \msc -# command. Doxygen will then run the mscgen tool (see: -# http://www.mcternan.me.uk/mscgen/)) to produce the chart and insert it in the -# documentation. The MSCGEN_PATH tag allows you to specify the directory where -# the mscgen tool resides. If left empty the tool is assumed to be found in the -# default search path. - -MSCGEN_PATH = - -# You can include diagrams made with dia in doxygen documentation. Doxygen will -# then run dia to produce the diagram and insert it in the documentation. The -# DIA_PATH tag allows you to specify the directory where the dia binary resides. -# If left empty dia is assumed to be found in the default search path. - -DIA_PATH = - # If set to YES the inheritance and collaboration graphs will hide inheritance # and usage relations if the target is undocumented or is not a class. # The default value is: YES. @@ -2179,7 +2439,7 @@ HIDE_UNDOC_RELATIONS = YES # If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is # available from the path. This tool is part of Graphviz (see: -# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent +# https://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent # Bell Labs. The other options in this section have no effect if this option is # set to NO # The default value is: NO. @@ -2196,35 +2456,52 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# When you want a differently looking font in the dot files that doxygen -# generates you can specify the font name using DOT_FONTNAME. You need to make -# sure dot is able to find the font, which can be done by putting it in a -# standard location or by setting the DOTFONTPATH environment variable or by -# setting DOT_FONTPATH to the directory containing the font. -# The default value is: Helvetica. +# DOT_COMMON_ATTR is common attributes for nodes, edges and labels of +# subgraphs. When you want a differently looking font in the dot files that +# doxygen generates you can specify fontname, fontcolor and fontsize attributes. +# For details please see Node, +# Edge and Graph Attributes specification You need to make sure dot is able +# to find the font, which can be done by putting it in a standard location or by +# setting the DOTFONTPATH environment variable or by setting DOT_FONTPATH to the +# directory containing the font. Default graphviz fontsize is 14. +# The default value is: fontname=Helvetica,fontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTNAME = Helvetica +DOT_COMMON_ATTR = "fontname=Helvetica,fontsize=10" -# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of -# dot graphs. -# Minimum value: 4, maximum value: 24, default value: 10. +# DOT_EDGE_ATTR is concatenated with DOT_COMMON_ATTR. For elegant style you can +# add 'arrowhead=open, arrowtail=open, arrowsize=0.5'. Complete documentation about +# arrows shapes. +# The default value is: labelfontname=Helvetica,labelfontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTSIZE = 10 +DOT_EDGE_ATTR = "labelfontname=Helvetica,labelfontsize=10" -# By default doxygen will tell dot to use the default font as specified with -# DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set -# the path where dot can find it using this tag. +# DOT_NODE_ATTR is concatenated with DOT_COMMON_ATTR. For view without boxes +# around nodes set 'shape=plain' or 'shape=plaintext' Shapes specification +# The default value is: shape=box,height=0.2,width=0.4. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_NODE_ATTR = "shape=box,height=0.2,width=0.4" + +# You can set the path where dot can find font specified with fontname in +# DOT_COMMON_ATTR and others dot attributes. # This tag requires that the tag HAVE_DOT is set to YES. DOT_FONTPATH = -# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for -# each documented class showing the direct and indirect inheritance relations. -# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO. +# If the CLASS_GRAPH tag is set to YES or GRAPH or BUILTIN then doxygen will +# generate a graph for each documented class showing the direct and indirect +# inheritance relations. In case the CLASS_GRAPH tag is set to YES or GRAPH and +# HAVE_DOT is enabled as well, then dot will be used to draw the graph. In case +# the CLASS_GRAPH tag is set to YES and HAVE_DOT is disabled or if the +# CLASS_GRAPH tag is set to BUILTIN, then the built-in generator will be used. +# If the CLASS_GRAPH tag is set to TEXT the direct and indirect inheritance +# relations will be shown as texts / links. +# Possible values are: NO, YES, TEXT, GRAPH and BUILTIN. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. CLASS_GRAPH = YES @@ -2238,7 +2515,8 @@ CLASS_GRAPH = YES COLLABORATION_GRAPH = YES # If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for -# groups, showing the direct groups dependencies. +# groups, showing the direct groups dependencies. See also the chapter Grouping +# in the manual. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. @@ -2261,10 +2539,32 @@ UML_LOOK = NO # but if the number exceeds 15, the total amount of fields shown is limited to # 10. # Minimum value: 0, maximum value: 100, default value: 10. -# This tag requires that the tag HAVE_DOT is set to YES. +# This tag requires that the tag UML_LOOK is set to YES. UML_LIMIT_NUM_FIELDS = 10 +# If the DOT_UML_DETAILS tag is set to NO, doxygen will show attributes and +# methods without types and arguments in the UML graphs. If the DOT_UML_DETAILS +# tag is set to YES, doxygen will add type and arguments for attributes and +# methods in the UML graphs. If the DOT_UML_DETAILS tag is set to NONE, doxygen +# will not generate fields with class member information in the UML graphs. The +# class diagrams will look similar to the default class diagrams but using UML +# notation for the relationships. +# Possible values are: NO, YES and NONE. +# The default value is: NO. +# This tag requires that the tag UML_LOOK is set to YES. + +DOT_UML_DETAILS = NO + +# The DOT_WRAP_THRESHOLD tag can be used to set the maximum number of characters +# to display on a single line. If the actual line length exceeds this threshold +# significantly it will wrapped across multiple lines. Some heuristics are apply +# to avoid ugly line breaks. +# Minimum value: 0, maximum value: 1000, default value: 17. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_WRAP_THRESHOLD = 17 + # If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and # collaboration graphs will show the relations between templates and their # instances. @@ -2331,10 +2631,17 @@ GRAPHICAL_HIERARCHY = YES DIRECTORY_GRAPH = YES +# The DIR_GRAPH_MAX_DEPTH tag can be used to limit the maximum number of levels +# of child directories generated in directory dependency graphs by dot. +# Minimum value: 1, maximum value: 25, default value: 1. +# This tag requires that the tag DIRECTORY_GRAPH is set to YES. + +DIR_GRAPH_MAX_DEPTH = 1 + # The DOT_IMAGE_FORMAT tag can be used to set the image format of the images # generated by dot. For an explanation of the image formats see the section # output formats in the documentation of the dot tool (Graphviz (see: -# http://www.graphviz.org/)). +# https://www.graphviz.org/)). # Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order # to make the SVG files visible in IE 9+ (other browsers do not have this # requirement). @@ -2344,7 +2651,7 @@ DIRECTORY_GRAPH = YES # The default value is: png. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_IMAGE_FORMAT = png +DOT_IMAGE_FORMAT = svg # If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to # enable generation of interactive SVG images that allow zooming and panning. @@ -2356,7 +2663,7 @@ DOT_IMAGE_FORMAT = png # The default value is: NO. # This tag requires that the tag HAVE_DOT is set to YES. -INTERACTIVE_SVG = NO +INTERACTIVE_SVG = YES # The DOT_PATH tag can be used to specify the path where the dot tool can be # found. If left blank, it is assumed the dot tool can be found in the path. @@ -2371,11 +2678,12 @@ DOT_PATH = DOTFILE_DIRS = -# The MSCFILE_DIRS tag can be used to specify one or more directories that -# contain msc files that are included in the documentation (see the \mscfile -# command). +# You can include diagrams made with dia in doxygen documentation. Doxygen will +# then run dia to produce the diagram and insert it in the documentation. The +# DIA_PATH tag allows you to specify the directory where the dia binary resides. +# If left empty dia is assumed to be found in the default search path. -MSCFILE_DIRS = +DIA_PATH = # The DIAFILE_DIRS tag can be used to specify one or more directories that # contain dia files that are included in the documentation (see the \diafile @@ -2384,13 +2692,18 @@ MSCFILE_DIRS = DIAFILE_DIRS = # When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the -# path where java can find the plantuml.jar file. If left blank, it is assumed -# PlantUML is not used or called during a preprocessing step. Doxygen will -# generate a warning when it encounters a \startuml command in this case and -# will not generate output for the diagram. +# path where java can find the plantuml.jar file or to the filename of jar file +# to be used. If left blank, it is assumed PlantUML is not used or called during +# a preprocessing step. Doxygen will generate a warning when it encounters a +# \startuml command in this case and will not generate output for the diagram. PLANTUML_JAR_PATH = +# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a +# configuration file for plantuml. + +PLANTUML_CFG_FILE = + # When using plantuml, the specified paths are searched for files specified by # the !include statement in a plantuml block. @@ -2420,18 +2733,6 @@ DOT_GRAPH_MAX_NODES = 50 MAX_DOT_GRAPH_DEPTH = 0 -# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent -# background. This is disabled by default, because dot on Windows does not seem -# to support this out of the box. -# -# Warning: Depending on the platform used, enabling this option may lead to -# badly anti-aliased labels on the edges of a graph (i.e. they become hard to -# read). -# The default value is: NO. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_TRANSPARENT = NO - # Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output # files in one run (i.e. multiple -o and -T options on the command line). This # makes dot run faster, but since only newer versions of dot (>1.8.10) support @@ -2444,14 +2745,34 @@ DOT_MULTI_TARGETS = NO # If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page # explaining the meaning of the various boxes and arrows in the dot generated # graphs. +# Note: This tag requires that UML_LOOK isn't set, i.e. the doxygen internal +# graphical representation for inheritance and collaboration diagrams is used. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. GENERATE_LEGEND = YES -# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot +# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate # files that are used to generate the various graphs. +# +# Note: This setting is not only used for dot files but also for msc temporary +# files. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. DOT_CLEANUP = YES + +# You can define message sequence charts within doxygen comments using the \msc +# command. If the MSCGEN_TOOL tag is left empty (the default), then doxygen will +# use a built-in version of mscgen tool to produce the charts. Alternatively, +# the MSCGEN_TOOL tag can also specify the name an external tool. For instance, +# specifying prog as the value, doxygen will call the tool as prog -T +# -o . The external tool should support +# output file formats "png", "eps", "svg", and "ismap". + +MSCGEN_TOOL = + +# The MSCFILE_DIRS tag can be used to specify one or more directories that +# contain msc files that are included in the documentation (see the \mscfile +# command). + +MSCFILE_DIRS = diff --git a/docs/index.rst b/docs/index.rst index 30ef672f84..89a5e3e836 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,31 +8,38 @@ Composable Kernel User Guide ******************************************************************** -The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. This document contains instructions for installing, using, and contributing to the Composable Kernel project. To learn more see :ref:`what-is-ck`. +The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. -The CK documentation is structured as follows: +The Composable Kernel repository is located at `https://github.com/ROCm/composable_kernel `_. .. grid:: 2 :gutter: 3 - .. grid-item-card:: Installation + .. grid-item-card:: Install - * :ref:`docker-hub` + * :doc:`Composable Kernel prerequisites <./install/Composable-Kernel-prerequisites>` + * :doc:`Build and install Composable Kernel <./install/Composable-Kernel-install>` + * :doc:`Build and install Composable Kernel on a Docker image <./install/Composable-Kernel-Docker>` .. grid-item-card:: Conceptual - * :ref:`what-is-ck` + * :doc:`Composable Kernel structure <./conceptual/Composable-Kernel-structure>` + * :doc:`Composable Kernel mathematical basis <./conceptual/Composable-Kernel-math>` - .. grid-item-card:: API reference + .. grid-item-card:: Tutorials - * :ref:`supported-primitives` - * :ref:`api-reference` - * :ref:`wrapper` + * :doc:`Composable Kernel examples and tests <./tutorial/Composable-Kernel-examples>` - .. grid-item-card:: Tutorial - - * :ref:`hello-world` + .. grid-item-card:: Reference + * :doc:`Composable Kernel supported scalar types <./reference/Composable_Kernel_supported_scalar_types>` + * :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>` + * :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>` + * :ref:`wrapper` + * :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>` + * :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>` + * :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>` + To contribute to the documentation refer to `Contributing to ROCm `_. You can find licensing information on the `Licensing `_ page. diff --git a/docs/install/Composable-Kernel-Docker.rst b/docs/install/Composable-Kernel-Docker.rst new file mode 100644 index 0000000000..d40cc2bff5 --- /dev/null +++ b/docs/install/Composable-Kernel-Docker.rst @@ -0,0 +1,16 @@ +.. meta:: + :description: Composable Kernel docker files + :keywords: composable kernel, CK, ROCm, API, docker + +.. _docker-hub: + +******************************************************************** +Composable Kernel Docker containers +******************************************************************** + +Docker images that include all the required prerequisites for building Composable Kernel are available on `Docker Hub `_. + +The images also contain `ROCm `_, `CMake `_, and the `ROCm LLVM compiler infrastructure `_. + +Composable Kernel Docker images are named according to their operating system and ROCm version. For example, a Docker image named ``ck_ub22.04_rocm6.3`` would correspond to an Ubuntu 22.04 image with ROCm 6.3. + diff --git a/docs/install/Composable-Kernel-install.rst b/docs/install/Composable-Kernel-install.rst new file mode 100644 index 0000000000..61b1fe0fcb --- /dev/null +++ b/docs/install/Composable-Kernel-install.rst @@ -0,0 +1,72 @@ +.. meta:: + :description: Composable Kernel build and install + :keywords: composable kernel, CK, ROCm, API, documentation, install + +****************************************************** +Building and installing Composable Kernel with CMake +****************************************************** + +Before you begin, clone the `Composable Kernel GitHub repository `_ and create a ``build`` directory in its root: + +.. code:: shell + + git clone https://github.com/ROCm/composable_kernel.git + cd composable_kernel + mkdir build + +Change directory to the ``build`` directory and generate the makefile using the ``cmake`` command. Two build options are required: + +* ``CMAKE_PREFIX_PATH``: The ROCm installation path. ROCm is installed in ``/opt/rocm`` by default. +* ``CMAKE_CXX_COMPILER``: The path to the Clang compiler. Clang is found at ``/opt/rocm/llvm/bin/clang++`` by default. + + +.. code:: shell + + cd build + cmake ../. -D CMAKE_PREFIX_PATH="/opt/rocm" -D CMAKE_CXX_COMPILER="/opt/rocm/llvm/bin/clang++" [-D [-D] ...] + + +Other build options are: + +* ``DISABLE_DL_KERNELS``: Set this to "ON" to not build deep learning (DL) and data parallel primitive (DPP) instances. + + .. note:: + + DL and DPP instances are useful on architectures that don't support XDL or WMMA. + +* ``CK_USE_FP8_ON_UNSUPPORTED_ARCH``: Set to ``ON`` to build FP8 data type instances on gfx90a without native FP8 support. +* ``GPU_TARGETS``: Target architectures. Target architectures in this list must all be different versions of the same architectures. Enclose the list of targets in quotation marks. Separate multiple targets with semicolons (``;``). For example, ``cmake -D GPU_TARGETS="gfx908;gfx90a"``. This option is required to build tests and examples. +* ``GPU_ARCHS``: Target architectures. Target architectures in this list are not limited to different versions of the same architectures. Enclose the list of targets in quotation marks. Separate multiple targets with semicolons (``;``). For example, ``cmake -D GPU_TARGETS="gfx908;gfx1100"``. +* ``CMAKE_BUILD_TYPE``: The build type. Can be ``None``, ``Release``, ``Debug``, ``RelWithDebInfo``, or ``MinSizeRel``. CMake will use ``Release`` by default. + +.. Note:: + + If neither ``GPU_TARGETS`` nor ``GPU_ARCHS`` is specified, Composable Kernel will be built for all targets supported by the compiler. + +Build Composable Kernel using the generated makefile. This will build the library, the examples, and the tests, and save them to ``bin``. + +.. code:: shell + + make -j20 + +The ``-j`` option speeds up the build by using multiple threads in parallel. For example, ``-j20`` uses twenty threads in parallel. On average, each thread will use 2GB of memory. Make sure that the number of threads you use doesn't exceed the available memory in your system. + +Using ``-j`` alone will launch an unlimited number of threads and is not recommended. + +Install the Composable Kernel library: + +.. code:: shell + + make install + +After running ``make install``, the Composable Kernel files will be saved to the following locations: + +* Library files: ``/opt/rocm/lib/`` +* Header files: ``/opt/rocm/include/ck/`` and ``/opt/rocm/include/ck_tile/`` +* Examples, tests, and ckProfiler: ``/opt/rocm/bin/`` + +For information about ckProfiler, see `the ckProfiler readme file `_. + +For information about running the examples and tests, see :doc:`Composable Kernel examples and tests <../tutorial/Composable-Kernel-examples>`. + + diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst new file mode 100644 index 0000000000..10be849ea6 --- /dev/null +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -0,0 +1,32 @@ +.. meta:: + :description: Composable Kernel prerequisites + :keywords: composable kernel, CK, ROCm, API, documentation, prerequisites + +****************************************************** +Composable Kernel prerequisites +****************************************************** + +Docker images that include all the required prerequisites for building Composable Kernel are available on `Docker Hub `_. + +The following prerequisites are required to build and install Composable Kernel: + +* cmake +* hip-rocclr +* iputils-ping +* jq +* libelf-dev +* libncurses5-dev +* libnuma-dev +* libpthread-stubs0-dev +* llvm-amdgpu +* mpich +* net-tools +* python3 +* python3-dev +* python3-pip +* redis +* rocm-llvm-dev +* zlib1g-dev +* libzstd-dev +* openssh-server +* clang-format-12 diff --git a/docs/install/dockerhub.rst b/docs/install/dockerhub.rst deleted file mode 100644 index 87eb5a4f81..0000000000 --- a/docs/install/dockerhub.rst +++ /dev/null @@ -1,101 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _docker-hub: - -******************************************************************** -CK Docker Hub -******************************************************************** - -Why do I need this? -=================== - -To make things simpler, and bring Composable Kernel and its dependencies together, -docker images can be found on `Docker Hub `_. Docker images provide a complete image of the OS, the Composable Kernel library, and its dependencies in a single downloadable file. - -Refer to `Docker Overview `_ for more information on Docker images and containers. - -Which image is right for me? -============================ - -The image naming includes information related to the docker image. -For example ``ck_ub20.04_rocm6.0`` indicates the following: - -* ``ck`` - made for running Composable Kernel; -* ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm6.0`` - ROCm platform version 6.0. - -Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. Use the ``docker pull`` command to download the file:: - - docker pull rocm/composable_kernel:ck_ub20.04_rocm6.0 - - -What is inside the image? -------------------------- - -The docker images have everything you need for running CK including: - -* `ROCm `_ -* `CMake `_ -* `Compiler `_ -* `Composable Kernel library `_ - -Running the docker container -============================ - -After downloading the docker image, you can start the container using one of a number of commands. Start with the ``docker run`` command as shown below:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm6.0 \ - /bin/bash - -After starting the bash shell, the docker container current folder is `~/workspace`. The library path is ``~/workspace/composable_kernel``. Navigate to the library to begin the tutorial as explained in :ref:`hello-world`: - -.. note:: - - If your current folder is different from `${HOME}`, adjust the line ``-v ${HOME}:/root/workspace`` in the ``docker run`` command to fit your folder structure. - -Stop and restart the docker image -================================= - -After finishing the tutorial, or just when you have completed your work session, you can close the docker container, or stop the docker container to restart it at another time. Closing the docker container means that it is still in the active state, and can be resumed from where you left it. Stopping the container closes it, and returns the image to its initial state. - -Use the ``Ctrl-D`` option to exit the container, while leaving it active, so you can return to the container in its current state to resume the tutorial, or pickup your project where you left off. - -To restart the active container use the ``docker exec`` command to specify the container name and options as follows:: - - docker exec -it bash - -Where: - -* `exec` is the docker command -* `-it` is the interactive option for `exec` -* `` specifies an active container on the system -* `bash` specifies the command to run in the interactive shell - -.. note:: - - You can use the ``docker container ls`` command to list the active containers on the system. - -To start a container from the image, use the ``docker start`` command:: - - docker start - -Then use the docker exec command as shown above to start the bash shell. - -Use the ``docker stop`` command to stop the container and restore the image to its initial state:: - - docker stop - -Editing the docker image -======================= - -If you want to customize the docker image, edit the -`Dockerfile `_ -from the GitHub repository to suit your needs. diff --git a/docs/reference/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst deleted file mode 100644 index 22222b0cf0..0000000000 --- a/docs/reference/API_Reference_Guide.rst +++ /dev/null @@ -1,54 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _api-reference: - -******************************************************************** -API reference guide -******************************************************************** - - -This document contains details of the APIs for the Composable Kernel (CK) library and introduces -some of the key design principles that are used to write new classes that extend CK functionality. - -================= -Using CK API -================= - -This section describes how to use the CK library API. - -================= -CK Datatypes -================= - ------------------ -DeviceMem ------------------ - -.. doxygenstruct:: DeviceMem - ---------------------------- -Kernels For Flashattention ---------------------------- - -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists -the classes that are used in the CK GPU implementation of Flashattention. - -**Gridwise classes** - -.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle - -**Blockwise classes** - -.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 - -.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 - -.. doxygenstruct:: ck::BlockwiseSoftmax - -**Threadwise classes** - -.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic - -.. bibliography:: diff --git a/docs/reference/wrapper.rst b/docs/reference/Composable-Kernel-wrapper.rst similarity index 88% rename from docs/reference/wrapper.rst rename to docs/reference/Composable-Kernel-wrapper.rst index 190fbcd445..4baa8d2b64 100644 --- a/docs/reference/wrapper.rst +++ b/docs/reference/Composable-Kernel-wrapper.rst @@ -1,20 +1,15 @@ .. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation + :description: Composable Kernel wrapper + :keywords: composable kernel, CK, ROCm, API, wrapper .. _wrapper: ******************************************************************** -Wrapper +Composable Kernel wrapper ******************************************************************** -------------------------------------- -Description -------------------------------------- - -The CK library provides a lightweight wrapper for more complex operations implemented in -the library. +The Composable Kernel library provides a lightweight wrapper to simplify the more complex operations. Example: diff --git a/docs/reference/Composable_Kernel_custom_types.rst b/docs/reference/Composable_Kernel_custom_types.rst new file mode 100644 index 0000000000..863d4131b9 --- /dev/null +++ b/docs/reference/Composable_Kernel_custom_types.rst @@ -0,0 +1,39 @@ +.. meta:: + :description: Composable Kernel supported custom types + :keywords: composable kernel, custom, data types, support, CK, ROCm + +****************************************************** +Composable Kernel custom data types +****************************************************** + +Composable Kernel supports the use of custom types that provide a way to implement specialized numerical formats. + +To use custom types, a C++ type that implements the necessary operations for tensor computations needs to be created. These should include: + +* Constructors and initialization methods +* Arithmetic operators if the type will be used in computational operations +* Any conversion functions needed to interface with other parts of an application + +For example, to create a complex half-precision type: + +.. code:: cpp + + struct complex_half_t + { + half_t real; + half_t img; + }; + + struct complex_half_t + { + using type = half_t; + type real; + type img; + + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + +Custom types can be particularly useful for specialized applications such as complex number arithmetic, +custom quantization schemes, or domain-specific number representations. + diff --git a/docs/reference/Composable_Kernel_supported_scalar_types.rst b/docs/reference/Composable_Kernel_supported_scalar_types.rst new file mode 100644 index 0000000000..7ea1a9eaeb --- /dev/null +++ b/docs/reference/Composable_Kernel_supported_scalar_types.rst @@ -0,0 +1,69 @@ +.. meta:: + :description: Composable Kernel supported scalar types + :keywords: composable kernel, scalar, data types, support, CK, ROCm + +*************************************************** +Composable Kernel supported scalar data types +*************************************************** + +The Composable Kernel library provides support for the following scalar data types: + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Type + - Bit Width + - Description + + * - ``double`` + - 64-bit + - Standard IEEE 754 double precision floating point + + * - ``float`` + - 32-bit + - Standard IEEE 754 single precision floating point + + * - ``int32_t`` + - 32-bit + - Standard signed 32-bit integer + + * - ``int8_t`` + - 8-bit + - Standard signed 8-bit integer + + * - ``uint8_t`` + - 8-bit + - Standard unsigned 8-bit integer + + * - ``bool`` + - 1-bit + - Boolean type + + * - ``ck::half_t`` + - 16-bit + - IEEE 754 half precision floating point with 5 exponent bits, 10 mantissa bits, and 1 sign bit + + * - ``ck::bhalf_t`` + - 16-bit + - Brain floating point with 8 exponent bits, 7 mantissa bits, and 1 sign bit + + * - ``ck::f8_t`` + - 8-bit + - 8-bit floating point (E4M3 format) with 4 exponent bits, 3 mantissa bits, and 1 sign bit + + * - ``ck::bf8_t`` + - 8-bit + - 8-bit brain floating point (E5M2 format) with 5 exponent bits, 2 mantissa bits, and 1 sign bit + + * - ``ck::f4_t`` + - 4-bit + - 4-bit floating point format (E2M1 format) with 2 exponent bits, 1 mantissa bit, and 1 sign bit + + * - ``ck::f6_t`` + - 6-bit + - 6-bit floating point format (E2M3 format) with 2 exponent bits, 3 mantissa bits, and 1 sign bit + + * - ``ck::bf6_t`` + - 6-bit + - 6-bit brain floating point format (E3M2 format) with 3 exponent bits, 2 mantissa bits, and 1 sign bit \ No newline at end of file diff --git a/docs/reference/Composable_Kernel_vector_utilities.rst b/docs/reference/Composable_Kernel_vector_utilities.rst new file mode 100644 index 0000000000..3103653191 --- /dev/null +++ b/docs/reference/Composable_Kernel_vector_utilities.rst @@ -0,0 +1,16 @@ +.. meta:: + :description: Composable Kernel supported precision types and custom type support + :keywords: composable kernel, precision, data types, ROCm + +****************************************************** +Composable Kernel vector template utilities +****************************************************** + +Composable Kernel includes template utilities for creating vector types with customizable widths. These template utilities also flatten nested vector types into a single, wider vector, preventing the creation of vectors of vectors. + +Vectors composed of supported scalar and custom types can be created with the ``ck::vector_type`` template. + +For example, ``ck::vector_type`` creates a vector composed of four floats and ``ck::vector_type`` creates a vector composed of eight half-precision scalars. + +For vector operations to be valid, the underlying types must be either a :doc:`supported scalar type ` or :doc:`a custom type ` that implements the required operations. + diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 533b81cd39..2ef3383d84 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -3,34 +3,43 @@ defaults: root: index subtrees: -- caption: Conceptual - entries: - - file: conceptual/what-is-ck.rst - title: What is Composable Kernel? - - caption: Install entries: - - file: install/dockerhub.rst - title: Docker Hub + - file: install/Composable-Kernel-prerequisites.rst + title: Composable Kernel prerequisites + - file: install/Composable-Kernel-install.rst + title: Build and install Composable Kernel + - file: install/Composable-Kernel-Docker.rst + title: Composable Kernel Docker images -- caption: CK API Reference +- caption: Conceptual entries: - - file: reference/Supported_Primitives_Guide.rst - title: Supported Primitives - - file: reference/API_Reference_Guide.rst - title: API Reference - - file: reference/wrapper.rst - title: Wrapper + - file: conceptual/Composable-Kernel-structure.rst + title: Composable Kernel structure + - file: conceptual/Composable-Kernel-math.rst + title: Composable Kernel mathematical basis - caption: Tutorial entries: - - file: tutorial/tutorial_hello_world.rst - title: Hello World Tutorial + - file: tutorial/Composable-Kernel-examples.rst + title: Composable Kernel examples + +- caption: Reference + entries: + - file: reference/Composable_Kernel_supported_scalar_types.rst + title: Composable Kernel scalar types + - file: reference/Composable_Kernel_custom_types.rst + title: Composable Kernel custom types + - file: reference/Composable_Kernel_vector_utilities.rst + title: Composable Kernel vector utilities + - file: reference/Composable-Kernel-wrapper.rst + title: Composable Kernel wrapper + - file: doxygen/html/annotated.rst + title: Composable Kernel class list - caption: About entries: - file: Contributors_Guide.rst - title: Contributing to CK + title: Contributing to Composable Kernel - file: license.rst title: License - \ No newline at end of file diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 246439bcdb..beedb4e867 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.7.2 -sphinxcontrib-bibtex==2.6.2 +rocm-docs-core[api_reference]==1.20.1 +sphinxcontrib-bibtex==2.6.5 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 185acda5b2..e8aa02aa01 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -6,71 +6,183 @@ # accessible-pygments==0.0.5 # via pydata-sphinx-theme -alabaster==0.7.16 +alabaster==1.0.0 # via sphinx -babel==2.15.0 +asttokens==3.0.0 + # via stack-data +attrs==25.3.0 + # via + # jsonschema + # jupyter-cache + # referencing +babel==2.17.0 # via # pydata-sphinx-theme # sphinx -beautifulsoup4==4.12.3 +beautifulsoup4==4.13.4 # via pydata-sphinx-theme -breathe==4.35.0 +breathe==4.36.0 # via rocm-docs-core -certifi==2024.7.4 +certifi==2025.1.31 # via requests -cffi==1.16.0 +cffi==1.17.1 # via # cryptography # pynacl -charset-normalizer==3.3.2 +charset-normalizer==3.4.1 # via requests -click==8.1.7 - # via sphinx-external-toc -cryptography==43.0.0 +click==8.1.8 + # via + # click-log + # doxysphinx + # jupyter-cache + # sphinx-external-toc +click-log==0.4.0 + # via doxysphinx +comm==0.2.2 + # via ipykernel +contourpy==1.3.2 + # via matplotlib +cryptography==44.0.2 # via pyjwt -deprecated==1.2.14 +cycler==0.12.1 + # via matplotlib +debugpy==1.8.14 + # via ipykernel +decorator==5.2.1 + # via ipython +deprecated==1.2.18 # via pygithub docutils==0.21.2 # via - # breathe # myst-parser # pybtex-docutils # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex -fastjsonschema==2.20.0 +doxysphinx==3.3.12 # via rocm-docs-core -gitdb==4.0.11 +exceptiongroup==1.2.2 + # via ipython +executing==2.2.0 + # via stack-data +fastjsonschema==2.21.1 + # via + # nbformat + # rocm-docs-core +fonttools==4.57.0 + # via matplotlib +gitdb==4.0.12 # via gitpython -gitpython==3.1.43 +gitpython==3.1.44 # via rocm-docs-core -idna==3.7 +greenlet==3.2.1 + # via sqlalchemy +idna==3.10 # via requests imagesize==1.4.1 # via sphinx -jinja2==3.1.4 +importlib-metadata==8.6.1 + # via + # jupyter-cache + # myst-nb +ipykernel==6.29.5 + # via myst-nb +ipython==8.35.0 + # via + # ipykernel + # myst-nb +jedi==0.19.2 + # via ipython +jinja2==3.1.6 # via # myst-parser # sphinx +jsonschema==4.23.0 + # via nbformat +jsonschema-specifications==2024.10.1 + # via jsonschema +jupyter-cache==1.0.1 + # via myst-nb +jupyter-client==8.6.3 + # via + # ipykernel + # nbclient +jupyter-core==5.7.2 + # via + # ipykernel + # jupyter-client + # nbclient + # nbformat +kiwisolver==1.4.8 + # via matplotlib latexcodec==3.0.0 # via pybtex +libsass==0.22.0 + # via doxysphinx +lxml==5.2.1 + # via doxysphinx markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 -mdit-py-plugins==0.4.1 +matplotlib==3.10.1 + # via doxysphinx +matplotlib-inline==0.1.7 + # via + # ipykernel + # ipython +mdit-py-plugins==0.4.2 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-parser==3.0.1 +mpire==2.10.2 + # via doxysphinx +myst-nb==1.2.0 # via rocm-docs-core -packaging==24.1 +myst-parser==4.0.1 + # via myst-nb +nbclient==0.10.2 # via + # jupyter-cache + # myst-nb +nbformat==5.10.4 + # via + # jupyter-cache + # myst-nb + # nbclient +nest-asyncio==1.6.0 + # via ipykernel +numpy==1.26.4 + # via + # contourpy + # doxysphinx + # matplotlib +packaging==25.0 + # via + # ipykernel + # matplotlib # pydata-sphinx-theme # sphinx -pybtex==0.24.0 +parso==0.8.4 + # via jedi +pexpect==4.9.0 + # via ipython +pillow==11.2.1 + # via matplotlib +platformdirs==4.3.7 + # via jupyter-core +prompt-toolkit==3.0.51 + # via ipython +psutil==7.0.0 + # via ipykernel +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pybtex==0.25.1 # via # pybtex-docutils # sphinxcontrib-bibtex @@ -82,40 +194,67 @@ pydata-sphinx-theme==0.15.4 # via # rocm-docs-core # sphinx-book-theme -pygithub==2.3.0 +pygithub==2.6.1 # via rocm-docs-core -pygments==2.18.0 +pygments==2.19.1 # via # accessible-pygments + # ipython + # mpire # pydata-sphinx-theme # sphinx -pyjwt[crypto]==2.8.0 +pyjson5==1.6.8 + # via doxysphinx +pyjwt[crypto]==2.10.1 # via pygithub pynacl==1.5.0 # via pygithub -pyyaml==6.0.1 +pyparsing==3.2.3 # via + # doxysphinx + # matplotlib +python-dateutil==2.9.0.post0 + # via + # jupyter-client + # matplotlib +pyyaml==6.0.2 + # via + # jupyter-cache + # myst-nb # myst-parser # pybtex # rocm-docs-core # sphinx-external-toc +pyzmq==26.4.0 + # via + # ipykernel + # jupyter-client +referencing==0.36.2 + # via + # jsonschema + # jsonschema-specifications requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.7.2 +rocm-docs-core[api-reference]==1.20.1 # via -r requirements.in -six==1.16.0 - # via pybtex -smmap==5.0.1 +rpds-py==0.24.0 + # via + # jsonschema + # referencing +six==1.17.0 + # via python-dateutil +smmap==5.0.2 # via gitdb snowballstemmer==2.2.0 # via sphinx -soupsieve==2.5 +soupsieve==2.7 # via beautifulsoup4 -sphinx==7.4.7 +sphinx==8.1.3 # via # breathe + # myst-nb # myst-parser # pydata-sphinx-theme # rocm-docs-core @@ -125,19 +264,19 @@ sphinx==7.4.7 # sphinx-external-toc # sphinx-notfound-page # sphinxcontrib-bibtex -sphinx-book-theme==1.1.3 +sphinx-book-theme==1.1.4 # via rocm-docs-core sphinx-copybutton==0.5.2 # via rocm-docs-core -sphinx-design==0.6.0 +sphinx-design==0.6.1 # via rocm-docs-core sphinx-external-toc==1.0.1 # via rocm-docs-core -sphinx-notfound-page==1.0.3 +sphinx-notfound-page==1.1.0 # via rocm-docs-core sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-bibtex==2.6.2 +sphinxcontrib-bibtex==2.6.5 # via -r requirements.in sphinxcontrib-devhelp==2.0.0 # via sphinx @@ -149,15 +288,46 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -tomli==2.0.1 +sqlalchemy==2.0.40 + # via jupyter-cache +stack-data==0.6.3 + # via ipython +tabulate==0.9.0 + # via jupyter-cache +tomli==2.2.1 # via sphinx -typing-extensions==4.12.2 +tornado==6.4.2 # via + # ipykernel + # jupyter-client +tqdm==4.67.1 + # via mpire +traitlets==5.14.3 + # via + # comm + # ipykernel + # ipython + # jupyter-client + # jupyter-core + # matplotlib-inline + # nbclient + # nbformat +typing-extensions==4.13.2 + # via + # beautifulsoup4 + # ipython + # myst-nb # pydata-sphinx-theme # pygithub -urllib3==2.2.2 + # referencing + # sqlalchemy +urllib3==2.4.0 # via # pygithub # requests -wrapt==1.16.0 +wcwidth==0.2.13 + # via prompt-toolkit +wrapt==1.17.2 # via deprecated +zipp==3.21.0 + # via importlib-metadata diff --git a/docs/tutorial/Composable-Kernel-examples.rst b/docs/tutorial/Composable-Kernel-examples.rst new file mode 100644 index 0000000000..62422d6f15 --- /dev/null +++ b/docs/tutorial/Composable-Kernel-examples.rst @@ -0,0 +1,40 @@ +.. meta:: + :description: Composable Kernel examples and tests + :keywords: composable kernel, CK, ROCm, API, examples, tests + +******************************************************************** +Composable Kernel examples and tests +******************************************************************** + +After :doc:`building and installing Composable Kernel <../install/Composable-Kernel-install>`, the examples and tests will be moved to ``/opt/rocm/bin/``. + +All tests have the prefix ``test`` and all examples have the prefix ``example``. + +Use ``ctest`` with no arguments to run all examples and tests, or use ``ctest -R`` to run a single test. For example: + +.. code:: shell + + ctest -R test_gemm_fp16 + +Examples can be run individually as well. For example: + +.. code:: shell + + ./bin/example_gemm_xdl_fp16 1 1 1 + +For instructions on how to run individual examples and tests, see their README files in the |example|_ and |test|_ GitHub folders. + +To run smoke tests, use ``make smoke``. + +To run regression tests, use ``make regression``. + +In general, tests that run for under thirty seconds are included in the smoke tests and tests that run for over thirty seconds are included in the regression tests. + +.. |example| replace:: ``example`` +.. _example: https://github.com/ROCm/composable_kernel/tree/develop/example + +.. |client_example| replace:: ``client_example`` +.. _client_example: https://github.com/ROCm/composable_kernel/tree/develop/client_example + +.. |test| replace:: ``test`` +.. _test: https://github.com/ROCm/composable_kernel/tree/develop/test \ No newline at end of file diff --git a/docs/tutorial/tutorial_hello_world.rst b/docs/tutorial/tutorial_hello_world.rst deleted file mode 100644 index c31460785b..0000000000 --- a/docs/tutorial/tutorial_hello_world.rst +++ /dev/null @@ -1,165 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _hello-world: - -******************************************************************** -Hello World Tutorial -******************************************************************** - -This tutorial is for engineers dealing with artificial intelligence and machine learning who -would like to optimize pipelines and improve performance using the Composable -Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example. - -Description -=========== - -Modern AI technology solves more and more problems in a variety of fields, but crafting fast and -efficient workflows is still challenging. CK can make the AI workflow fast -and efficient. CK is a collection of optimized AI operator kernels with tools to create -new kernels. The library has components required for modern neural network architectures -including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators. - -CK library acceleration features are based on: - -* Layered structure -* Tile-based computation model -* Tensor coordinate transformation -* Hardware acceleration use -* Support of low precision data types including fp16, bf16, int8 and int4 - -If you need more technical details and benchmarking results read the following -`blog post `_. - -To download the library visit the `composable_kernel repository `_. - -Hardware targets -================ - -CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are -supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture. - -========== ========= -GPU Target AMD GPU -========== ========= -gfx908 Radeon Instinct MI100 -gfx90a Radeon Instinct MI210, MI250, MI250X -gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT -========== ========= - -There are also `cloud options `_ you can find if -you don't have an AMD GPU at hand. - -Build the library -================= - -This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. - -.. note:: - - You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. - -Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library:: - - cd composable_kernel/ - -Create and change to a ``build`` directory:: - - mkdir build && cd build - -The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag:: - - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D BUILD_DEV=OFF \ - -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. - -If everything goes well the CMake command will return:: - - -- Configuring done - -- Generating done - -- Build files have been written to: "/root/workspace/composable_kernel/build" - -Finally, you can build examples and tests:: - - make -j examples tests - -When complete you should see:: - - Scanning dependencies of target tests - [100%] Built target tests - -Run examples and tests -====================== - -Examples are listed as test cases as well, so you can run all examples and tests with:: - - ctest - -You can check the list of all tests by running:: - - ctest -N - -You can also run examples separately as shown in the following example execution:: - - ./bin/example_gemm_xdl_fp16 1 1 1 - -The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. - -If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle LoopScheduler: Interwave, PipelineVersion: v1 - -However, running it on a `gfx1030` device should result in the following:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem - -Don't worry, some operators are supported on `gfx1030` architecture, so you can run a -separate example like:: - - ./bin/example_gemm_dl_fp16 1 1 1 - -and it should return something like:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} - arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} - arg.c_grid_desc_m_n_{ 3840, 4096} - launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... - Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> - -.. note:: - - A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture. - -You can also run a separate test:: - - ctest -R test_gemm_fp16 - -If everything goes well you should see something like:: - - Start 121: test_gemm_fp16 - 1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec - - 100% tests passed, 0 tests failed out of 1 - -Summary -======= - -In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task. - -P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance. diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 98fd9c6b77..e6a26ecafd 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -28,11 +28,36 @@ add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) + add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) + + +add_example_executable(example_gemm_xdl_fp16_fp8_streamk_v3 gemm_xdl_fp16_fp8_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3) + add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) +set(GEMM_OPTIONS) +list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16") +example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS}) + + +list(APPEND gpu_list gfx942 gfx950) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp) + add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp) + add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp) + add_example_executable(example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp) + add_example_executable(example_gemm_xdl_fp8_pk_i4_v3 gemm_xdl_fp8_pk_i4_v3.cpp) + set(target 1) + endif() +endforeach() + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) @@ -42,9 +67,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) - add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) @@ -58,7 +80,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) @@ -67,6 +89,12 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) + + add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3) + + add_example_executable(example_gemm_xdl_fp8_streamk_v3 gemm_xdl_fp8_streamk_v3.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_streamk_v3) set(target 1) endif() endforeach() @@ -83,3 +111,20 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16) +add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8) + +add_example_executable(example_gemm_wmma_bf16_v3 gemm_wmma_bf16_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_v3) +add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3) +add_example_executable(example_gemm_wmma_fp8_v3 gemm_wmma_fp8_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3) +add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 3d8f4565cb..434f549443 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -14,6 +15,8 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" @@ -21,6 +24,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" struct ProblemSize final { @@ -28,9 +32,9 @@ struct ProblemSize final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; }; struct ProblemSizeStreamK final @@ -39,11 +43,11 @@ struct ProblemSizeStreamK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; - ck::index_t NumSKBlocks = -1; + ck::index_t NumSKBlocks = -1; // number of stream-k blocks }; struct ProblemSizeStreamK_universal final { @@ -51,12 +55,13 @@ struct ProblemSizeStreamK_universal final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; - ck::index_t Grid_size = -1; // defaults to max occupancy - ck::index_t Streamk_sel = 1; // defaults to 1-tile SK + ck::index_t Grid_size = -1; // defaults to max occupancy + ck::index_t Streamk_sel = 1; // defaults to 1-tile SK + ck::StreamKReductionStrategy reduction_strategy = ck::StreamKReductionStrategy::Atomic; }; struct ProblemSizeSplitK final @@ -65,18 +70,19 @@ struct ProblemSizeSplitK final ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; ck::index_t KBatch = 1; }; struct ExecutionConfig final { - bool do_verification = true; - int init_method = 2; - bool time_kernel = false; + // 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU + int do_verification = 1; + int init_method = 2; + bool time_kernel = false; }; template @@ -125,11 +131,12 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl; return false; } @@ -169,18 +176,33 @@ bool parse_cmd_args(int argc, if(argc >= 11) { problem_size.Streamk_sel = std::stoi(argv[10]); - problem_size.Grid_size = std::stoi(argv[11]); + + if(argc >= 12) + { + problem_size.Grid_size = std::stoi(argv[11]); + + if(argc >= 13) + { + int reduction_strategy = std::stoi(argv[12]); + problem_size.reduction_strategy = reduction_strategy == 0 + ? ck::StreamKReductionStrategy::Atomic + : ck::StreamKReductionStrategy::Reduction; + } + } } } else { std::cerr - << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" - << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + << std::endl + << "arg11: Grid_size(-1 for max occupancy)" << std::endl + << "arg12: Reduction strategy (0: Atomic, 1: Reduction)" << std::endl; return false; } @@ -224,13 +246,14 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" - << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; } @@ -274,14 +297,119 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: KBatch" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: KBatch" << std::endl; return false; } return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +float i4_to_f32_gfx9(uint8_t i4) +{ + static std::unordered_map u = {{0b1000, -0.5000f}, + {0b1001, -0.4375f}, + {0b1010, -0.3750f}, + {0b1011, -0.3125f}, + {0b1100, -0.2500f}, + {0b1101, -0.1875f}, + {0b1110, -0.1250f}, + {0b1111, -0.0625f}, + {0b0, +0.0000f}, + {0b1, +0.0625f}, + {0b10, +0.1250f}, + {0b11, +0.1875f}, + {0b100, +0.2500f}, + {0b101, +0.3125f}, + {0b110, +0.3750f}, + {0b111, +0.4375f}}; + + return u[i4]; +} diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index b5fecb9752..b9284b2783 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 212b72f2a6..1684213641 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index 1840390aa9..1e64e9a0a3 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp index 7a9e3f6186..30faf542dd 100644 --- a/example/01_gemm/gemm_dpp_fp16.cpp +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -34,6 +34,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device:: + ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_bf16.cpp b/example/01_gemm/gemm_wmma_bf16.cpp new file mode 100644 index 0000000000..a87426094f --- /dev/null +++ b/example/01_gemm/gemm_wmma_bf16.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp new file mode 100644 index 0000000000..69ced56c0b --- /dev/null +++ b/example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 32; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, KPerBlock, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, + ADataType, ADataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_bf16_v3.cpp b/example/01_gemm/gemm_wmma_bf16_v3.cpp new file mode 100644 index 0000000000..1dc5c5286f --- /dev/null +++ b/example/01_gemm/gemm_wmma_bf16_v3.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index f8afe8d6db..28ab878ac3 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -68,6 +68,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp new file mode 100644 index 0000000000..359d823ac2 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp new file mode 100644 index 0000000000..ec5e48a86a --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 32; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, KPerBlock, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, + ADataType, ADataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp new file mode 100644 index 0000000000..7225dba721 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 128, 64, + 64, 8, 8, + 16, 16, + 4, 2, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + 1, 1, S<1, 32, 1, 4>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp new file mode 100644 index 0000000000..0376820b7b --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::f8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 128, 64, 64, + 8, 8, + 16, 16, + 4, 2, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 4>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, + ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceComputeType = ck::f8_t; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel support gfx12 only" << std::endl; + + return 0; + } + return !run_gemm_splitk_example(argc, argv); +} diff --git a/example/01_gemm/gemm_wmma_int8.cpp b/example/01_gemm/gemm_wmma_int8.cpp new file mode 100644 index 0000000000..a88e42d42b --- /dev/null +++ b/example/01_gemm/gemm_wmma_int8.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" + +using ADataType = int8_t; +using BDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int32_t; +using CDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 2, + 2, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 3cac55ef47..6cfff30dbd 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp new file mode 100644 index 0000000000..7178ad46b9 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 128; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 128, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 2, + S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp deleted file mode 100644 index cc14dcb8eb..0000000000 --- a/example/01_gemm/gemm_xdl_bf16_rtn.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/utility/type_convert.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" - -using ADataType = ck::bhalf_t; -using BDataType = ck::bhalf_t; -using CDataType = ck::bhalf_t; -using AccDataType = float; -using CShuffleDataType = float; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp new file mode 100644 index 0000000000..5b56a43483 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, + 64, 8, 8, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 2338cdc9c1..414683ffdf 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -31,9 +31,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; -// // clang-format on -// clang-format off using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| @@ -47,6 +45,17 @@ using DeviceGemmInstance = DeviceGemmInstance1; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp index 979a200791..a996d034e6 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -42,6 +42,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp new file mode 100644 index 0000000000..bd38eb17ee --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 64, + 16, 16, + 256, 8, 16, + 16, 16, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp index 2e27fc66f9..b0e36b394b 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -1,12 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" -using ADataType = ck::f8_t; -using BDataType = ck::half_t; +using ADataType = ck::half_t; +using BDataType = ck::f8_t; using AccDataType = float; using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; @@ -29,15 +29,15 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 64, 16, 16, - 64, 16, 8, + 256, 8, 16, 16, 16, 1, 1, - S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 16, 16, 0, - S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>; + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000..f83d479713 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 32, + 32, 32, + 4, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp index 5b163962b9..36ac51f1da 100644 --- a/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_streamk_v3.cpp @@ -8,7 +8,7 @@ using ADataType = ck::half_t; using BDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = ck::half_t; +using CShuffleDataType = float; using CDataType = ck::half_t; using ALayout = Row; @@ -43,6 +43,17 @@ using DeviceGemmV2_Streamk_Instance = using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example_streamk_v2.inc" int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp index eba0ea9d11..ecd3b7be5d 100644 --- a/example/01_gemm/gemm_xdl_fp16_v2.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -46,6 +46,17 @@ using DeviceGemmInstance = using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp index ad370f570e..4a969246cd 100644 --- a/example/01_gemm/gemm_xdl_fp16_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using ALayout = Row; -using BLayout = Row; +using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; @@ -27,17 +27,17 @@ using DeviceGemmV2Instance = ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, PassThrough, PassThrough, PassThrough, GemmDefault, - 256, - 224, 256, - 64, 8, 2, + 64, + 16, 16, + 256, 8, 8, 16, 16, - 7, 8, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 8, 2, 0, - 1, 2, S<1, 32, 1, 8>, 8, - ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 8361576299..5afb3d1554 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -41,6 +41,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl BElementOp, CElementOp>; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index fe41602301..0c51a58037 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -32,11 +32,27 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + // this instance has been tested working on gfx950 + // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceComputeType = float; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index acc5fbc515..1dec165abd 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,6 +44,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp new file mode 100644 index 0000000000..266a1e9d3e --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -0,0 +1,358 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp" + +using F8 = ck::f8_t; +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F8; +using BDataType = I4; +using AccDataType = F32; +using CShuffleDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = false; + +// clang-format off +#if 0 +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, + 256, 16, 32, + 32, 32, + 4, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 4, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>; + +#else +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 256, 256, + 128, 16, 32, + 32, 32, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F8, F8, PermuteA, PermuteB>; + +#endif +// clang-format on + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_preshuffled:" << b_k_n_preshuffled.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize() / + 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + + // weight pre-shuffle + int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8 + int NLane = gemm.GetPreShuffleParameters(); + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k); + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_preshuffled(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_preshuffled(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_preshuffled(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_preshuffled(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_preshuffled(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_f32({K, N}); + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + uint8_t i4 = 0; + + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + + float v_b = i4_to_f32_gfx9(i4); + b_k_n_f32(k, n) = v_b; + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_f32, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp new file mode 100644 index 0000000000..0575314dff --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -0,0 +1,336 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using F8 = ck::f8_t; +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F8; +using BDataType = I4; +using AccDataType = float; +using CShuffleDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 128; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, + KPerBlock, 16, 32, + 32, 32, + 2, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; + +// clang-format on + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_f32({K, N}); + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + uint8_t i4 = 0; + + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + + float v_b = i4_to_f32_gfx9(i4); + b_k_n_f32(k, n) = v_b; + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_f32, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp new file mode 100644 index 0000000000..3b79ae9b85 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 256, + 128, 16, 16, + 16, 16, + 4, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 1, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index cc03200b9d..3237f1a61c 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,6 +33,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp index d29cb74cd6..26ea31f20b 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on #else // clang-format off @@ -53,6 +53,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp index e99249389e..75971bdecf 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include @@ -52,6 +52,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp index 7d433b6145..41665a79b7 100644 --- a/example/01_gemm/gemm_xdl_streamk.cpp +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -15,7 +15,6 @@ using F16 = ck::half_t; using ALayout = Row; using BLayout = Row; -// using BLayout = Col; using CLayout = Row; using AElementOp = PassThrough; @@ -34,16 +33,33 @@ using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>; // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>; +// instance for double rate mfma instruction on gfx950 +using DeviceGemmStreamK2 = ck::tensor_operation::device::DeviceGemmXdlStreamK +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; - -// // clang-format on // clang-format on -using DeviceGemmInstance = DeviceGemmStreamK; +using DeviceGemmInstance = DeviceGemmStreamK; +using DeviceGemmInstance2 = DeviceGemmStreamK2; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; -#include "run_gemm_example.inc" +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example_streamk.inc" int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index b0f963fee5..d8672f6a0c 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -37,6 +37,17 @@ using DeviceGemmInstance = DeviceGemmInstance; using ReferenceGemmInstance = ck::tensor_operation::host:: ReferenceGemm; +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index a6f0d0bcfe..6c5d9f9fba 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -3,90 +3,6 @@ #pragma once -#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" - -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else if constexpr(std::is_same_v) - { - return 2e-1; - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -116,21 +32,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) }; auto f_get_default_stride = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(stride == 0) + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) { - // give a chance if stride is zero, return a default packed stride + // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { - return col; + return static_cast(col); } else { - return row; + return static_cast(row); } } else - return stride; + return static_cast(stride); }; StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); @@ -143,8 +59,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) switch(config.init_method) { case 0: - ck::utils::FillConstant{static_cast(1.f)}(a_m_k); - ck::utils::FillConstant{static_cast(1.f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(1.f)}(b_k_n); break; case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); @@ -173,6 +89,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -193,6 +110,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); @@ -203,23 +122,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); float ave_time = 0; - if constexpr(std::is_same::value && - !std::is_base_of::value) + if constexpr(std::is_same::value) { auto argument = gemm.MakeArgument( #ifdef BUILD_INT4_EXAMPLE @@ -250,61 +158,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); } - else if constexpr(std::is_same::value && - std::is_base_of::value) - { - auto argument = gemm.MakeArgument( -#ifdef BUILD_INT4_EXAMPLE - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), -#else - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), -#endif - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - problem_size.NumSKBlocks); - - if(!gemm.IsSupportedArgument(argument)) - { - std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return true; - } - - std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); - if(workspace_size != 0) - { - workspace.Realloc(workspace_size); - gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer()); - } - - ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - -#if 0 - // TODO!!!!! - if(workspace_size != 0){ - float * ws_ptr = reinterpret_cast(malloc(workspace_size)); - size_t ws_dwords = workspace_size / sizeof(float); - workspace.FromDevice(ws_ptr); - - for(size_t i = 0; i < ws_dwords; i++) { - uint32_t rere = reinterpret_cast(ws_ptr)[i]; - printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere); - } - free(ws_ptr); - } -#endif - } else { // When the Problem Type and Problem Size does not fit. @@ -325,14 +178,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl; - if(config.do_verification) + bool pass = true; + + if((config.do_verification == 1) || (config.do_verification == 3)) { + // CPU verification auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + std::cout << "Running verification on CPU." << std::endl; ref_invoker.Run(ref_argument); #ifdef BUILD_INT4_EXAMPLE @@ -346,15 +203,45 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!", - get_rtol(), - get_atol()); + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); #endif } - return true; + if((config.do_verification == 2) || (config.do_verification == 3)) + { + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." << std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + return pass == true; } bool run_gemm_example(int argc, char* argv[]) @@ -364,11 +251,3 @@ bool run_gemm_example(int argc, char* argv[]) return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); } - -bool run_gemm_streamk_example(int argc, char* argv[]) -{ - ProblemSizeStreamK problem_size; - ExecutionConfig config; - - return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); -} diff --git a/example/01_gemm/run_gemm_example_streamk.inc b/example/01_gemm/run_gemm_example_streamk.inc new file mode 100644 index 0000000000..7e43847463 --- /dev/null +++ b/example/01_gemm/run_gemm_example_streamk.inc @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/host_utility/device_prop.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + ck::utils::FillConstant{ck::type_convert(1.f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(1.f)}(b_k_n); + break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + case 3: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK; + + // do GEMM + static_assert(std::is_base_of::value && + std::is_base_of::value); + auto gemm = DeviceGemmInstance{}; + auto gemm2 = DeviceGemmInstance2{}; // instance for double rate mfma instruction + BaseStreamK* op_ptr = (ck::get_device_name() == "gfx950") ? static_cast(&gemm2) + : static_cast(&gemm); + + float ave_time = 0; + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + auto argument_ptr = op_ptr->MakeArgumentPointer( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + problem_size.NumSKBlocks); + + if(!op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::cerr << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + auto argument = argument_ptr.get(); + std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + op_ptr->SetWorkSpacePointer(argument, workspace.GetDeviceBuffer()); + } + + ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_ptr->GetTypeString() << std::endl; + + bool pass = true; + + if((config.do_verification == 1) || (config.do_verification == 3)) + { + // CPU verification + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + std::cout << "Running verification on CPU." << std::endl; + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); +#endif + } + + if((config.do_verification == 2) || (config.do_verification == 3)) + { + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." << std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + return pass == true; +} + +bool run_gemm_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 6679f95157..2700838bcc 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -103,6 +21,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto Grid_size = problem_size.Grid_size; auto Streamk_sel = problem_size.Streamk_sel; + auto reduction_strategy = problem_size.reduction_strategy; + if(reduction_strategy == ck::StreamKReductionStrategy::Atomic) + { + std::cout << "Using Atomic reduction strategy" << std::endl; + } + else + { + std::cout << "Using Parallel reduction strategy" << std::endl; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) @@ -117,7 +45,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) @@ -176,6 +104,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -196,6 +125,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) * + c_m_n_device_ref_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); @@ -231,7 +162,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Grid_size, a_element_op, b_element_op, - c_element_op); + c_element_op, + reduction_strategy); if(!gemm.IsSupportedArgument(argument)) { @@ -240,8 +172,15 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) return true; } + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer()); + } + bool pass = true; - if(config.do_verification) + if((config.do_verification == 1) || (config.do_verification == 3)) { auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -271,6 +210,36 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #endif } + if((config.do_verification == 2) || (config.do_verification == 3)) + { + // GPU verification + auto ref_gemm_gpu = ReferenceGemmInstanceGPU{}; + auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker(); + + auto ref_argument_gpu = ref_gemm_gpu.MakeArgument( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_ref_buf.GetDeviceBuffer()), + M, + N, + K, + a_element_op, + b_element_op, + c_element_op); + + std::cout << "Running verification on GPU." << std::endl; + ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{}); + + c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data()); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_device_ref_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + if(config.time_kernel) { ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); @@ -284,7 +253,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; + << " GB/s, " << gemm.GetTypeString() + << (reduction_strategy == ck::StreamKReductionStrategy::Atomic ? " (Atomic)" + : " (Reduction)") + << std::endl; } return pass; } diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index ad7238f0dd..4adb6f896b 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -115,21 +33,21 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) }; auto f_get_default_stride = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(stride == 0) + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1 || stride == 0) { - // give a chance if stride is zero, return a default packed stride + // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { - return col; + return static_cast(col); } else { - return row; + return static_cast(row); } } else - return stride; + return static_cast(stride); }; StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); @@ -228,7 +146,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) } bool pass = true; - if(config.do_verification) + if((config.do_verification == 1) || (config.do_verification == 3)) { auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -261,7 +179,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) if(config.time_kernel) { ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4}); + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 100, true, 4}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index be47665a26..562936418b 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp index de7af85fb3..67b3e646f7 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -34,7 +34,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm::RLayout; struct ExecutionConfig final { bool do_verification = true; - int init_method = 1; + int init_method = 2; bool time_kernel = false; }; diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index cebfeb51d6..d61aee81a4 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, Tensor conv_output_device(conv_output_g_n_k_wos_desc); Tensor r0_device(r0_desc); + std::cout << "input: " << conv_input.mDesc << std::endl; + std::cout << "weight: " << conv_weight.mDesc << std::endl; + std::cout << "output: " << conv_output_device.mDesc << std::endl; + std::cout << "reduction: " << r0_device.mDesc << std::endl << std::endl; + switch(config.init_method) { case 0: break; case 1: ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); - ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_weight); + ck::utils::FillUniformDistributionIntegerValue{-1, 1}(conv_weight); + break; + case 2: + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); + ck::utils::FillUniformDistribution{-1, 1}(conv_weight); break; default: - ck::utils::FillUniformDistribution{-5, 5}(conv_input); - ck::utils::FillUniformDistribution{-5, 5}(conv_weight); + ck::utils::FillUniformDistribution{-8, 7}(conv_input); + ck::utils::FillUniformDistribution{-1, 1}(conv_weight); } DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize()); @@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, return false; } + // XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0. + r0_device_buf.SetValue(ck::NumericLimits::Lowest()); + const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - const std::size_t flop = problem_size.GetFlops(); - const std::size_t num_btype = problem_size.GetByte(); + if(config.time_kernel) + { + const std::size_t flop = problem_size.GetFlops(); + const std::size_t num_btype = problem_size.GetByte(); - const float tflops = static_cast(flop) / 1.E9 / avg_time; - const float gb_per_sec = num_btype / 1.E6 / avg_time; - std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << conv.GetTypeString() << std::endl; + const float tflops = static_cast(flop) / 1.E9 / avg_time; + const float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv.GetTypeString() << std::endl; + } + else + { + std::cout << "FINISHED: " << conv.GetTypeString() << std::endl; + } if(config.do_verification) { @@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, BElementOp{}, PassThrough{}); + std::cout << "\nRunning verification on CPU." << std::endl; ref_invoker.Run(ref_argument); Tensor r0_host(r0_device.mDesc); @@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, conv_output_device_buf.FromDevice(conv_output_device.mData.data()); r0_device_buf.FromDevice(r0_device.mData.data()); - return ck::utils::check_err(conv_output_device, - conv_output_host, - "Error: incorrect results! (Matrix E)", - 1e-5f, - 1e-4f) && - ck::utils::check_err( - r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f); + auto pass = ck::utils::check_err(conv_output_device, + conv_output_host, + "Error: incorrect results! (Matrix E)", + 1e-3f, + 1e-3f); + pass = + pass && ck::utils::check_err( + r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f); + if(pass) + std::cout << "Verification on CPU: PASS" << std::endl; + + return pass; } return true; diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp index ecff7b4713..117a18e3bd 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); for(int j = 0; j < NumDMatrices; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); for(int j = 0; j < NumDMatrices; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); } } } @@ -246,7 +246,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co // do GEMM auto argument = gemm.MakeArgument( p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); - gemm.SetKBatchSize(argument, config.k_batch); + gemm.SetKBatchSize(&argument, config.k_batch); if(!gemm.IsSupportedArgument(argument)) { throw std::runtime_error( @@ -257,7 +257,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); invoker.Run(argument, StreamConfig{nullptr, false, 1}); diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 965a0e7e37..63a2aea0b3 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -91,7 +91,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { auto group_count = problem_size.group_count; - using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments; + using KernelArguments = ck::tensor_operation::device::GroupedGemmKernelArgument; using GemmDesc = ck::tensor_operation::device::GemmDesc; // GEMM shape @@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count); - d_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size std::size_t flop = 0, num_btype = 0; @@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); for(int j = 0; j < NumDs; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); for(int j = 0; j < NumDs; ++j) { - d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential{}); } } } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index a193fc39ba..5bdc993192 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } - d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; @@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm.GetDeviceKernelArgSize(&argument), hipMemcpyHostToDevice)); - gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_kernel_args_dev.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 1a2bcfb33e..6806bd1886 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } } @@ -239,7 +239,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 0a63a29843..8418c10f5e 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } } @@ -240,7 +240,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); gemm.SetKBatch(argument, config.k_batch); invoker.Run(argument, StreamConfig{nullptr, false}); diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 320870e0de..7186c22233 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + #pragma once struct ProblemSize final @@ -18,6 +21,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool async_hargs = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) @@ -124,8 +128,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } } @@ -168,9 +172,30 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto argument = gemm.MakeArgument( p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); - DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument); + std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument); - gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + DeviceMem gemm_workspace, gemm_kargs; + void* gemm_hargs; + + // The following is necessary since TwoStage kernel is using additional memory both + // for Workspace and kernel arguments. + if(kargs_size > 0) + { + gemm_kargs.Realloc(kargs_size); + gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer()); + } + if(workspace_size > 0 && workspace_size != kargs_size) + { + gemm_workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); + } + if(config.async_hargs && hargs_size > 0) + { + hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size)); + gemm.SetHostKernelArgsPointer(&argument, gemm_hargs); + } if(!gemm.IsSupportedArgument(argument)) { @@ -179,7 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - invoker.Run(argument, StreamConfig{nullptr, false}); + if(!config.async_hargs) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + } + else + { + hipStream_t stream0 = nullptr; + hip_check_error(hipStreamCreate(&stream0)); + + hipEvent_t event0 = nullptr; + hip_check_error(hipEventCreate(&event0)); + + invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + + hip_check_error(hipEventSynchronize(event0)); + hip_check_error(hipStreamSynchronize(stream0)); + } bool pass = true; if(config.do_verification) @@ -247,18 +288,25 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.async_hargs = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: async hargs (0=n0, 1=yes)\n"); exit(0); } diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index 2f6533d448..a46eaa4816 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -198,7 +198,7 @@ int main() throw std::runtime_error("wrong! this device_op instance does not support this problem"); } - // init reducetion buffer to 0 + // init reduction buffer to 0 r0_device_buf.SetZero(); r1_device_buf.SetZero(); diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 94ed129dc0..03ba0a65df 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index 497ea19e11..6fbaee7dba 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -13,3 +13,9 @@ add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bw add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) + +add_example_executable(example_grouped_conv_bwd_weight_v3_xdl_bf16 grouped_conv_bwd_weight_v3_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_xdl_bf16) + +add_example_executable(example_grouped_conv_bwd_weight_v3_xdl_fp16 grouped_conv_bwd_weight_v3_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_v3_xdl_fp16) diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp new file mode 100644 index 0000000000..c0168bf117 --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_bf16.cpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +using InDataType = BF16; +// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory +using WeiDataType = F32; +using OutDataType = BF16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + // clang-format on + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 32, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl + // clang-format off + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp new file mode 100644 index 0000000000..4b9fd356df --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_v3_xdl_fp16.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 32, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<2, 0, 1>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + false, // ABlockLdsAddExtraM + S<4, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<2, 0, 1>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + false, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index 90d80f9f03..277fea0272 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -175,8 +175,8 @@ int main(int argc, char* argv[]) b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); } c0_n_bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 4cb45be7c9..d515720944 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -9,6 +9,12 @@ add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) +add_example_executable(example_batched_gemm_xdl_bf16_v3 batched_gemm_xdl_bf16_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16_v3) + +add_example_executable(example_batched_gemm_xdl_fp8_rowwise_v3 batched_gemm_xdl_fp8_rowwise_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp8_rowwise_v3) + add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) @@ -16,3 +22,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() + +add_example_executable(example_batched_gemm_xdl_fp16int4_b_scale_v3 batched_gemm_xdl_fp16int4_b_scale_v3.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16int4_b_scale_v3) diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp new file mode 100644 index 0000000000..548500518f --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmDefault, + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 0, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<8>, // CDEShuffleBlockTransferScalarPerVectors + ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler + ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion + >; + +#include "run_batched_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp new file mode 100644 index 0000000000..42171bcdb7 --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -0,0 +1,82 @@ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = F32; +using CShuffleDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto PermuteA = false; +static constexpr bool PermuteB = false; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 256; + +// clang-format off +using DeviceBatchedGemmV2Instance = + ck::tensor_operation::device::DeviceBatchedGemm_Xdl_CShuffleV3_BScale< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 16, 64, + KPerBlock, 8, 32, + 16, 16, + 1, 1, + S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 16, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm; +#include "run_batched_gemm_example_fp16int4_b_scale.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_fp16_int4_b_scale_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp new file mode 100644 index 0000000000..84f92eba8e --- /dev/null +++ b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + +using ADataType = F8; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmDefault, + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 0, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors + ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler + ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion + F8 // ComputeTypeA + >; + +#include "run_batched_gemm_example_rowwise.inc" + +int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 21934add31..741512bf00 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -210,7 +210,39 @@ bool run_batched_gemm_example(int argc, char* argv[]) problem_size.M = 256 * (dis(gen) + 1); problem_size.N = 128 * (dis(gen) + 1); - problem_size.K = 64 * (dis(gen) + 2); + problem_size.K = 128 * (dis(gen) + 2); + + problem_size.batch_count = 2; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 8) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + problem_size.batch_count = std::stoi(argv[7]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("optinal\n"); + printf("arg4-7: M = %d N = %d K = %d Batch = %d\n", + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.batch_count); + exit(0); + } problem_size.stride_A = problem_size.K; problem_size.stride_B = problem_size.K; @@ -220,21 +252,5 @@ bool run_batched_gemm_example(int argc, char* argv[]) problem_size.batch_stride_B = problem_size.K * problem_size.N; problem_size.batch_stride_C = problem_size.M * problem_size.N; - problem_size.batch_count = 16; - - if(argc == 4) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); - } - return run_batched_gemm(problem_size, config); } diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc new file mode 100644 index 0000000000..3582bc5e33 --- /dev/null +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -0,0 +1,579 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include + +#pragma once +struct ProblemSize final +{ + ck::index_t M = 128; + ck::index_t N = 128; + ck::index_t K = 384; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + // Batched Gemm count + ck::index_t batch_count = 2; + + // Split K count + ck::index_t KBatch = 1; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count, + KBatch] = problem_size; + + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t batch_BScale_Stride = + ((K + Scale_Block_K - 1) / Scale_Block_K) * ((N + Scale_Block_N - 1) / Scale_Block_N); + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor b_g_k_n_permute( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor b1_g_k_n( + f_host_tensor_descriptor(batch_count, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + batch_BScale_Stride, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_g_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "b1_g_k_n: " << b1_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() / + 2); + DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * + c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + printf("a_g_m_k size: %zu, b_g_k_n size: %zu, b1_g_k_n size: %zu, c_g_m_n size: %zu\n", + a_g_m_k.mDesc.GetElementSpaceSize(), + b_g_k_n_permute.mDesc.GetElementSpaceSize(), + b1_g_k_n.mDesc.GetElementSpaceSize(), + c_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + printf("Permute B\n"); + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int bs = 0; bs < batch_count; bs++) + { + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_g_k_n_permute(bs * batch_stride_B + j * N * K1 + i * K1 + jj) = + b_g_k_n(bs * batch_stride_B + i * K + (j * K1 + jj)); + } + } + } + } + } + else + { + b_g_k_n_permute = b_g_k_n; + } + + // vector pk_i4x4 permute + for(int bs = 0; bs < batch_count; bs++) + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_g_k_n_permute(bs, j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_g_k_n_permute(bs, j + 6, i) = i4x2; + } + } + } + } + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b_g_k_n_device_buf.ToDevice(b_g_k_n_permute.mData.data()); + b1_g_scale_device_buf.ToDevice(b1_g_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceBatchedGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + stride_A, + stride_B, + stride_C, + Scale_Stride_BN, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_BScale_Stride, + static_cast(b1_g_scale_device_buf.GetDeviceBuffer()), + batch_count, // batch count + KBatch, // split K count + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + Tensor b_g_k_n_dequant({batch_count, K, N}); + if(config.do_verification) + { + float v_b = 0; + for(int bs = 0; bs < batch_count; bs++) + { + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_g_k_n_dequant(bs, k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_g_k_n(bs, k / Scale_Block_K, n / Scale_Block_N)); + } + } + } + + auto ref_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_g_m_k, + b_g_k_n_dequant, + c_g_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + hip_check_error(hipDeviceSynchronize()); + + c_g_m_n_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_g_m_n_device_result, + c_g_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + +#if 0 + // print A matrix + printf("A matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&a_g_m_k(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < K; j++) + { + printf("%.2f,", static_cast(a_g_m_k(bs, i, j))); + } + printf("\n"); + } + } + + // print B matrix original + printf("B matrix original:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b_g_k_n(bs, 0, 0))); + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_g_k_n(bs, k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + printf("%d,", static_cast(i4)); + } + printf("\n"); + } + } + + // print B matrix + printf("B matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b_g_k_n_dequant(bs, 0, 0))); + for(int i = 0; i < K; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(b_g_k_n_dequant(bs, i, j))); + } + printf("\n"); + } + } + + // print B scale matrix + printf("B Scale matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&b1_g_k_n(bs, 0, 0))); + for(int i = 0; i < (K + Scale_Block_K - 1) / Scale_Block_K; i++) + { + for(int j = 0; j < (N + Scale_Block_N - 1) / Scale_Block_N; j++) + { + printf("%.2f, ", static_cast(b1_g_k_n(bs, i, j))); + } + printf("\n"); + } + } + + // print C matrix + printf("C matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf( + "batch %d -> Address: %p\n", bs, static_cast(&c_g_m_n_device_result(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(c_g_m_n_device_result(bs, i, j))); + } + printf("\n"); + } + } + + printf("C reference matrix:\n"); + for(int bs = 0; bs < batch_count; bs++) + { + printf("batch %d -> Address: %p\n", bs, static_cast(&c_g_m_n_host_result(bs, 0, 0))); + for(int i = 0; i < M; i++) + { + for(int j = 0; j < N; j++) + { + printf("%.2f, ", static_cast(c_g_m_n_host_result(bs, i, j))); + } + printf("\n"); + } + } +#endif + + return pass; +} + +bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 15); + + problem_size.M = 128 * (dis(gen) + 1); + problem_size.N = 128 * (dis(gen) + 1); + problem_size.K = 256 * (dis(gen) + 2); + + problem_size.batch_count = 2; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 7) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + if(argc >= 8) + { + problem_size.batch_count = std::stoi(argv[7]); + } + + if(argc >= 9) + { + problem_size.KBatch = std::stoi(argv[8]); + } + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + problem_size.batch_stride_C = problem_size.M * problem_size.N; + + return run_batched_gemm(problem_size, config); +} diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc new file mode 100644 index 0000000000..778be8ffd7 --- /dev/null +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include + +#pragma once + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t stride_A = K; + ck::index_t stride_B = K; + ck::index_t stride_C = N; + + ck::index_t stride_D0 = 0; + ck::index_t stride_D1 = 0; + + ck::index_t batch_stride_A = M * K; + ck::index_t batch_stride_B = K * N; + ck::index_t batch_stride_C = M * N; + + ck::index_t batch_stride_D0 = N; + ck::index_t batch_stride_D1 = M; + + ck::index_t batch_count = 16; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto& [M, + N, + K, + stride_A, + stride_B, + stride_C, + stride_D0, + stride_D1, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D0, + batch_stride_D1, + batch_count] = problem_size; + + // GEMM shape + auto f_host_tensor_descriptor = [](std::size_t batch_count_, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + } + else + { + return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + } + }; + + Tensor a_g_m_k( + f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, ALayout{})); + Tensor b_g_k_n( + f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, BLayout{})); + Tensor d0_g_m_n( + f_host_tensor_descriptor(batch_count, M, N, stride_D0, batch_stride_D0, D0Layout{})); + Tensor d1_g_m_n( + f_host_tensor_descriptor(batch_count, M, N, stride_D1, batch_stride_D1, D1Layout{})); + Tensor e_g_m_n_device_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl; + std::cout << "d1_g_m_n: " << d1_g_m_n.mDesc << std::endl; + std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + d0_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_g_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_g_m_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + d0_device_buf.ToDevice(d0_g_m_n.mData.data()); + d1_device_buf.ToDevice(d1_g_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = + gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + batch_count, + stride_A, + stride_B, + {stride_D0, stride_D1}, + stride_C, + batch_stride_A, + batch_stride_B, + {batch_stride_D0, batch_stride_D1}, + batch_stride_C, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; + + if(config.do_verification) + { + c_device_buf.FromDevice(e_g_m_n_device_result.mData.data()); + + Tensor c_g_m_n({batch_count, M, N}); + + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + Tensor e_g_m_n_host_result( + f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, ELayout{})); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int b = 0; b < batch_count; ++b) + { + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_g_m_n_host_result(b, m, n), + c_g_m_n(b, m, n), + d0_g_m_n(b, m, n), + d1_g_m_n(b, m, n)); + } + } + } + + pass = ck::utils::check_err( + e_g_m_n_device_result, e_g_m_n_host_result, "Error: Incorrect results c"); + } + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * batch_count * M * N * K; + std::size_t num_btype = sizeof(ADataType) * batch_count * M * K + + sizeof(BDataType) * batch_count * K * N + + sizeof(EDataType) * batch_count * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + return pass ? 0 : 1; +} + +bool run_batched_gemm_rowwise_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 15); + + problem_size.M = 256 * (dis(gen) + 1); + problem_size.N = 128 * (dis(gen) + 1); + problem_size.K = 128 * (dis(gen) + 2); + + problem_size.batch_count = 2; + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 8) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + problem_size.batch_count = std::stoi(argv[7]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("optinal\n"); + printf("arg4-7: M = %d N = %d K = %d Batch = %d\n", + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.batch_count); + exit(0); + } + + problem_size.stride_A = problem_size.K; + problem_size.stride_B = problem_size.K; + problem_size.stride_C = problem_size.N; + + problem_size.stride_D0 = 0; + problem_size.stride_D1 = 0; + + problem_size.batch_stride_A = problem_size.M * problem_size.K; + problem_size.batch_stride_B = problem_size.K * problem_size.N; + problem_size.batch_stride_C = problem_size.M * problem_size.N; + + problem_size.batch_stride_D0 = problem_size.N; + problem_size.batch_stride_D1 = problem_size.M; + + return run_batched_gemm_rowwise(problem_size, config); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index e3370b880b..627e20e245 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -32,6 +32,56 @@ using BiasLayout = typename LayoutSettingSelector::BiasLayout; template using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; +// instance for double rate mfma on gfx950 (vs gfx942) +template +using DeviceConvFwdInstance2 = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 1, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 16, 1, 16>, + 4>; +// instance for gfx942- template using DeviceConvFwdInstance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< @@ -54,14 +104,14 @@ using DeviceConvFwdInstance = 1, // 256, // BlockSize 128, // MPerBlock - 256, // NPerBlock - 16, // KPerBlock + 128, // NPerBlock + 32, // KPerBlock 4, // AK1 4, // BK1 32, // MPerXdl 32, // NPerXdl 2, // MXdlPerWave - 4, // NXdlPerWave + 2, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -184,40 +234,67 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, copy(conv_param.input_right_pads_, input_right_pads); // do Conv - auto conv = DeviceConvFwdInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = - conv.MakeArgument(in_device_buf.GetDeviceBuffer(), - wei_device_buf.GetDeviceBuffer(), - std::array{bias_device_buf.GetDeviceBuffer(), - residual_device_buf.GetDeviceBuffer()}, - out_device_buf.GetDeviceBuffer(), - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - std::array, 2>{ - {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, - std::array, 2>{ - {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + using BaseGroupedConvFwdMultipleABD = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + ck::Tuple, + OutKernelDataType, + InElementOp, + WeiElementOp, + OutElementOp, + InKernelDataType, // AComputeDataType + InKernelDataType>; // BComputeDataType - if(!conv.IsSupportedArgument(argument)) + static_assert( + std::is_base_of>::value && + std::is_base_of>::value); + + auto conv = DeviceConvFwdInstance{}; // instance for gfx942- + auto conv2 = DeviceConvFwdInstance2{}; // instance for double rate mfma instruction + // on gfx950 + BaseGroupedConvFwdMultipleABD* op_ptr = + (ck::get_device_name() == "gfx950") ? static_cast(&conv2) + : static_cast(&conv); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + auto argument_ptr = op_ptr->MakeArgumentPointer( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer(), + residual_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 2>{ + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, + std::array, 2>{ + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!op_ptr->IsSupportedArgument(argument_ptr.get())) { throw std::runtime_error( "wrong! device_conv with the specified compilation parameters does " "not support this Conv problem"); } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, config.time_kernel}); std::size_t flop = conv_param.GetFlops(); std::size_t num_btype = conv_param.GetByte(); @@ -225,7 +302,7 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_btype / 1.E6 / avg_time; std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << conv.GetTypeString() << std::endl; + << op_ptr->GetTypeString() << std::endl; if(config.do_verification) { diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 8b648a7f73..3e8c9afd9d 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) -if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) endif() diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index f329146728..d545508680 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) break; default: a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 27602e2313..1514fc48b3 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -157,7 +157,7 @@ int run(int argc, char* argv[]) break; default: a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index fa76faea84..2b02069e65 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -118,7 +118,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc index 2e77479bcc..e0ccb6dad1 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -153,7 +153,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc index 9ff4c56e06..0ad031cc71 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -178,7 +178,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index ea1e2734a6..cdfd86dff4 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -152,7 +152,7 @@ int run(int argc, char* argv[]) break; default: a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); - b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc index 609d085299..7ac29f33ca 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -156,7 +156,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc index b05915c07f..fb9b1b0bd7 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -156,7 +156,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc index 3fdaaebb0f..2cb69380e5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { @@ -173,7 +173,7 @@ int run(int argc, char* argv[]) b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); } diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index e3690984ab..cb1d3410c9 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + #pragma once struct ProblemSize final @@ -66,8 +69,8 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); } DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp index 97a3f89e5e..fc55019fc4 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on #else diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 36dcf58d70..f27dc60541 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. /* Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] @@ -60,14 +60,14 @@ struct AddAddRelu { const ck::half_t x = c + d0 + d1; - ck::tensor_operation::element_wise::Relu{}.template operator()(e, x); + ck::tensor_operation::element_wise::Relu{}.operator()(e, x); } __host__ __device__ void operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const { const float x = c + (d0 + d1); - ck::tensor_operation::element_wise::Relu{}.template operator()(e, x); + ck::tensor_operation::element_wise::Relu{}.operator()(e, x); } }; @@ -377,7 +377,7 @@ int main(int argc, char* argv[]) break; default: a0_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential{}); d00_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); d01_g_m_n.GenerateTensorValue(GeneratorTensor_1{1}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 8a0474156c..6af8ac6488 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -41,7 +41,7 @@ struct ExecutionConfig final { bool do_verification = true; int init_method = 1; - bool time_kernel = true; + bool time_kernel = false; }; #define DefaultConvParams \ diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 8ab56b21a6..c5c5a84b67 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) -if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) endif() diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index afbf948683..867493465d 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -5,3 +5,4 @@ add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permu add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) +add_example_executable(elementwise_scale_permute_amax_2D_fp16_fp8 elementwise_scale_permute_amax_2D_fp16_fp8.cpp) diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp new file mode 100644 index 0000000000..9431a8cde4 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/utility/reduction_enums.hpp" + +using F16 = ck::half_t; +using F32 = float; +using F8 = ck::f8_t; + +using InputDataType = F16; +using ScaleDataType = F32; +using OutputDataType = F8; + +static constexpr ck::index_t NumDim = 2; + +constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename ck::reduce_binary_operator::opType; + +struct ScalePassThrough +{ + ScalePassThrough(const float alpha = 1.f) : alpha_(alpha) {} + + __host__ __device__ constexpr void + operator()(OutputDataType& y0, OutputDataType& y1, const InputDataType& x0) const + { + y0 = ck::type_convert(ck::type_convert(x0) * alpha_); + y1 = y0; + } + + const ScaleDataType alpha_; +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryAbs = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + ScalePassThrough, // Elementwise + NumDim, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8, 1>>; // OutScalarPerVectorSeq + +using DeviceReduceInstance = + ck::tensor_operation::device::DeviceReduceMultiBlock; // OutDstVectorSize + +void reference_scale_permute_amax(Tensor& input, + Tensor& host_output_scaled_casted_transposed, + Tensor& host_output_scaled_casted, + Tensor& host_output_amax, + const float scale) +{ + ScalePassThrough out_element_op(scale); + const ck::index_t M = input.GetLengths()[0]; + const ck::index_t K = input.GetLengths()[1]; + for(ck::index_t m = 0; m < M; m++) + { + for(ck::index_t k = 0; k < K; k++) + { + OutputDataType y0, y1; + out_element_op(y0, y1, input(m, k)); + + host_output_scaled_casted(m, k) = y0; + host_output_scaled_casted_transposed(m, k) = y1; + const OutputDataType y_fabs = + ck::type_convert(ck::math::abs(ck::type_convert(y0))); + host_output_amax(0) = ck::type_convert(ck::math::max( + ck::type_convert(y_fabs), ck::type_convert(host_output_amax(0)))); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + bool time_kernel = true; + + const float scale = 2.f; + + ck::index_t M = 1024; + ck::index_t K = 1024; + + if(argc == 3) + { + M = std::stoi(argv[1]); + K = std::stoi(argv[2]); + } + + std::array dims = {M, K}; + std::array in_strides = {K, 1}; + std::array out_strides = {1, M}; + + Tensor input(dims, in_strides); + Tensor output_scaled_casted_transposed(dims, out_strides); + Tensor output_scaled_casted(dims, in_strides); + Tensor output_amax({1}); + + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem input_dev_buf(sizeof(InputDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem output_scaled_casted_transposed_dev_buf( + sizeof(OutputDataType) * output_scaled_casted_transposed.mDesc.GetElementSpaceSize()); + DeviceMem output_scaled_casted_dev_buf(sizeof(OutputDataType) * + output_scaled_casted.mDesc.GetElementSpaceSize()); + DeviceMem output_amax_dev_buf(sizeof(OutputDataType) * output_amax.mDesc.GetElementSpaceSize()); + + input_dev_buf.ToDevice(input.mData.data()); + + std::array inputs = {input_dev_buf.GetDeviceBuffer()}; + std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), + output_scaled_casted_dev_buf.GetDeviceBuffer()}; + + std::cout << "Input: " << input.mDesc << std::endl; + std::cout << "Scale: " << scale << std::endl; + std::cout << "Output scaled casted transposed: " << output_scaled_casted_transposed.mDesc + << std::endl; + std::cout << "Output scaled casted: " << output_scaled_casted.mDesc << std::endl; + std::cout << "Output amax: " << output_amax.mDesc << std::endl; + + auto launch_transpose_scale = [&]() { + auto transposeScale = DeviceElementwisePermuteInstance{}; + auto argument = transposeScale.MakeArgumentPointer(dims, + {in_strides}, + {out_strides, in_strides}, + inputs, + outputs, + ScalePassThrough{scale}); + + if(!transposeScale.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto transposeScale_invoker_ptr = transposeScale.MakeInvokerPointer(); + return transposeScale_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + }; + + auto launch_reduce = [&]() { + auto reduce = DeviceReduceInstance{}; + auto reduce_argument_ptr = + reduce.MakeArgumentPointer(dims, + in_strides, + {1}, // Output Lengths + {1}, // Output Strides + {0, 1}, // Reduce Dims + static_cast(1.f), + static_cast(0.f), + output_scaled_casted_dev_buf.GetDeviceBuffer(), + nullptr, + output_amax_dev_buf.GetDeviceBuffer(), + nullptr, + UnaryAbs{}, + PassThrough{}); + + if(!reduce.IsSupportedArgument(reduce_argument_ptr.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + return invoker_ptr->Run(reduce_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + }; + + float ave_time = launch_transpose_scale(); + ave_time += launch_reduce(); + std::cout << "Perf: " << ave_time << " ms" << std::endl; + bool pass = true; + + if(do_verification) + { + Tensor host_output_scaled_casted_transposed(dims, out_strides); + Tensor host_output_scaled_casted(dims, in_strides); + Tensor host_output_amax({1}); + + reference_scale_permute_amax(input, + host_output_scaled_casted_transposed, + host_output_scaled_casted, + host_output_amax, + scale); + + output_scaled_casted_transposed_dev_buf.FromDevice( + output_scaled_casted_transposed.mData.data()); + output_scaled_casted_dev_buf.FromDevice(output_scaled_casted.mData.data()); + output_amax_dev_buf.FromDevice(output_amax.mData.data()); + + pass &= ck::utils::check_err(output_scaled_casted_transposed.mData, + host_output_scaled_casted_transposed.mData, + "Error: Incorrect results scaled transposed", + 1e-3, + 1e-3); + pass &= ck::utils::check_err(output_scaled_casted.mData, + host_output_scaled_casted.mData, + "Error: Incorrect results scaled", + 1e-3, + 1e-3); + pass &= ck::utils::check_err( + output_amax.mData, host_output_amax.mData, "Error: Incorrect results amax", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp index a90a6340a4..392cb155cb 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -248,7 +248,7 @@ int main(int argc, char* argv[]) d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); break; default: - a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp index 742fd5547a..055d253042 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -194,9 +194,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b1_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); - b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); diff --git a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp index 809c1a956c..1ba8133ea7 100644 --- a/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp +++ b/example/59_grouped_gemm_multi_ABD/grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp @@ -184,9 +184,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + a1_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); } d0_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index ab136d99ba..79fafed4eb 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(convscale_add) add_subdirectory(convscale_reduce) add_subdirectory(multi_AB) add_subdirectory(unary) +add_subdirectory(dynamic_unary) add_custom_target(example_convnd_activ_xdl) # ScaleAdd ScaleAdd Relu diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt index 9d90cdd244..b9584be89c 100644 --- a/example/62_convnd_activ/binary/CMakeLists.txt +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt index 07f42075bd..7aae090674 100644 --- a/example/62_convnd_activ/convinvscale/CMakeLists.txt +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 9264da24a6..26f6c1b168 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp index 978221f8e1..bf560f8a43 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_convscale_common.hpp @@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification, { case 0: break; case 1: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5 + in.GenerateTensorValue(GeneratorTensor_2{-5, 6}); + wei.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); diff --git a/example/62_convnd_activ/convscale_add/CMakeLists.txt b/example/62_convnd_activ/convscale_add/CMakeLists.txt index 40cfd74aa4..b2e0eecb58 100644 --- a/example/62_convnd_activ/convscale_add/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_add/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt index ff9020a707..739c855ae4 100644 --- a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt index 95589cedcb..c3241aecf2 100644 --- a/example/62_convnd_activ/convscale_relu/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt new file mode 100644 index 0000000000..8441030945 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt @@ -0,0 +1,45 @@ +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_dynamic_unary_xdl) + # Sigmoid + add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_sigmoid_fp16) + # Tanh + add_example_executable(example_convnd_fwd_xdl_dynamic_tanh_fp16 convnd_fwd_xdl_dynamic_tanh_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_tanh_fp16) + # Relu + add_example_executable(example_convnd_fwd_xdl_dynamic_relu_fp16 convnd_fwd_xdl_dynamic_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_relu_fp16) + # SoftRelu + add_example_executable(example_convnd_fwd_xdl_dynamic_softrelu_fp16 convnd_fwd_xdl_dynamic_softrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_softrelu_fp16) + # Abs + add_example_executable(example_convnd_fwd_xdl_dynamic_abs_fp16 convnd_fwd_xdl_dynamic_abs_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_abs_fp16) + # Pow + add_example_executable(example_convnd_fwd_xdl_dynamic_pow_fp16 convnd_fwd_xdl_dynamic_pow_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_pow_fp16) + # Clipped Relu + add_example_executable(example_convnd_fwd_xdl_dynamic_clippedrelu_fp16 convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_clippedrelu_fp16) + # Leaky Relu + add_example_executable(example_convnd_fwd_xdl_dynamic_leakyrelu_fp16 convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_leakyrelu_fp16) + # Elu + add_example_executable(example_convnd_fwd_xdl_dynamic_elu_fp16 convnd_fwd_xdl_dynamic_elu_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_elu_fp16) + # Swish + add_example_executable(example_convnd_fwd_xdl_dynamic_swish_fp16 convnd_fwd_xdl_dynamic_swish_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_swish_fp16) + # PassThrough + add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fwd_xdl_dynamic_passthrough_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16) + # Logistic + add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp) + add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16) + + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp new file mode 100644 index 0000000000..ed31be19ee --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using DynamicElementOp = ck::tensor_operation::element_wise::DynamicUnaryOp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGroupedConvNDActivInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + DynamicElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_abs_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_abs_fp16.cpp new file mode 100644 index 0000000000..8fa455c62e --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_abs_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::UnaryAbs out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp new file mode 100644 index 0000000000..239a21525b --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::ClippedRelu out_element_op(0.f, 1.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_elu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_elu_fp16.cpp new file mode 100644 index 0000000000..23a094af70 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_elu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Elu out_element_op(2.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp new file mode 100644 index 0000000000..fe4b80a681 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::LeakyRelu out_element_op(0.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_logistic_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_logistic_fp16.cpp new file mode 100644 index 0000000000..756c07ed85 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_logistic_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Logistic out_element_op(1.0f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_passthrough_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_passthrough_fp16.cpp new file mode 100644 index 0000000000..6588ec5044 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_passthrough_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::PassThrough out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_pow_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_pow_fp16.cpp new file mode 100644 index 0000000000..90f00a166a --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_pow_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Power out_element_op(4.f, 1.f, 2.f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_relu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_relu_fp16.cpp new file mode 100644 index 0000000000..830297cb56 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_relu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Relu out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp new file mode 100644 index 0000000000..b143b4a4eb --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Sigmoid out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_softrelu_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_softrelu_fp16.cpp new file mode 100644 index 0000000000..83ba0f7f8c --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_softrelu_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::SoftRelu out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_swish_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_swish_fp16.cpp new file mode 100644 index 0000000000..e862d1120a --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_swish_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::Swish out_element_op(1.0f); + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_tanh_fp16.cpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_tanh_fp16.cpp new file mode 100644 index 0000000000..a91fc7ce30 --- /dev/null +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_xdl_dynamic_tanh_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_dynamic_unary_common.hpp" + +#include "../run_convnd_activ_dynamic_example.inc" + +int main(int argc, char* argv[]) +{ + + ck::tensor_operation::element_wise::TanH out_element_op; + return !run_convnd_example(argc, argv, out_element_op); +} diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt index c89c82d384..149bd6f03e 100644 --- a/example/62_convnd_activ/multi_AB/CMakeLists.txt +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/62_convnd_activ/run_convnd_activ_dynamic_example.inc b/example/62_convnd_activ/run_convnd_activ_dynamic_example.inc new file mode 100644 index 0000000000..4e90cf9366 --- /dev/null +++ b/example/62_convnd_activ/run_convnd_activ_dynamic_example.inc @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +template +bool run_convnd_example(int argc, char* argv[], const OutElementOp& out_element_op) +{ + print_helper_msg(); + + bool do_verification = true; + // Use floats for SoftRelu by default to avoid overflow after e^x. + int init_method = + std::is_same_v ? 2 : 1; + bool time_kernel = false; + + // Following shapes are selected to avoid overflow. Expect inf in case of + // size increase for some elementwise ops. + ck::utils::conv::ConvParam conv_param{ + 3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + const auto run = [&]() { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + if(conv_param.num_dim_spatial_ == 3) + { + return run(); + } + + return false; +} diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt index 3470e9b945..36b4ffc9f4 100644 --- a/example/62_convnd_activ/unary/CMakeLists.txt +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -1,4 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index d39114013b..9f4c43338e 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,3 +1,73 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) +set(EXAMPLE_COMPILE_OPTIONS) +# Open it when SGBPack branch landed on mainline +# list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental") +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) +add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) +add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) + +list(APPEND gpu_list gfx942 gfx950) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) + add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + if(hip_VERSION_FLAT LESS_EQUAL 600342132) + set(EXAMPLE_COMPILE_OPTIONS) + check_cxx_compiler_flag("-mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1" HAS_MAX_ILP_SCHEDULING_STRATEGY) + if(HAS_MAX_ILP_SCHEDULING_STRATEGY) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + endif() + example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + endif() + set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") + example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) + set(target 1) + endif() +endforeach() + +set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") +set(BLOCKSCALE_GEMM_OPTIONS ) +check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) +check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) + +if(hip_VERSION_FLAT LESS 600443483 OR hip_VERSION_FLAT GREATER_EQUAL 700000000) + if(HAS_MISCHED_BOTTOMUP) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") + elseif(HAS_MISCHED_PRERA_DIRECTION) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup") + endif() +else() + if(HAS_MISCHED_BOTTOMUP) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1") + elseif(HAS_MISCHED_PRERA_DIRECTION) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-prera-direction=bottomup") + endif() +endif() + +check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL) +if(HAS_MAX_OCCUPANCY_EXPERIMENTAL) + list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental) +endif() +# list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1") +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) + +example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..69803c7eeb --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()(BF16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(F16); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 32, 128, 128, + 8, 8, + 32, 32, + 1, 1, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, false, 1}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index cb4f60764e..352d373ae5 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -69,18 +69,21 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = MultiplyMultiply; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 // clang-format off -///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| -///###### RRR - ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; -///###### RCR - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; + , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -229,7 +232,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50, true, 50}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 2568754648..5aa978fbf0 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -55,7 +55,7 @@ using CDEElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t Scale_Block_M = 128; +static constexpr ck::index_t Scale_Block_M = 1; static constexpr ck::index_t Scale_Block_N = 128; static constexpr ck::index_t Scale_Block_K = 128; @@ -71,7 +71,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_ 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, + 1, 2, S<1, 32, 1, 8>, S<8>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on @@ -80,11 +80,12 @@ int main(int argc, char* argv[]) bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool flush_cache = true; // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; ck::index_t StrideA = K; ck::index_t StrideB = K; @@ -100,7 +101,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 10) + else if(argc == 8) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -110,16 +111,19 @@ int main(int argc, char* argv[]) N = std::stoi(argv[5]); K = std::stoi(argv[6]); - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideE = std::stoi(argv[9]); + flush_cache = std::stoi(argv[7]); + + StrideA = K; + StrideB = K; + StrideE = N; } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); exit(0); } @@ -182,9 +186,15 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); break; case 4: - a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); break; default: @@ -194,6 +204,16 @@ int main(int argc, char* argv[]) b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); } #endif +#if 0 + for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){ + float row_sum = .0; + for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){ + printf("%lf ",a1_m_k(im, ik)); + row_sum += a1_m_k(im, ik); + } + printf("sum: %lf\n", row_sum * 128); + } +#endif DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); @@ -205,7 +225,6 @@ int main(int argc, char* argv[]) a1_device_buf.ToDevice(a1_m_k.mData.data()); b0_device_buf.ToDevice(b0_k_n.mData.data()); b1_device_buf.ToDevice(b1_k_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -240,12 +259,24 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); - std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + float ave_time = .0; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -253,8 +284,6 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - if(do_verification) { Tensor c_m_n({M, N}); diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp new file mode 100644 index 0000000000..d64266bccf --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using A1Layout = Col; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 1, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + exit(0); + } + + // Transpose the AScale tensor for better performance + ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AK, + A1Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = 0.0f; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + Tensor c_m_n({M, N}); + Tensor a_m_k({M, K}); + Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + +#if 1 + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } +#endif + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp new file mode 100644 index 0000000000..fe1eca51b0 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using B0DataType = FP8; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()(BF16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 256, 256, 128, + 16, 16, + 16, 16, + 16, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + ck::index_t Warmup = 50; + ck::index_t Repeat = 50; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else if(argc == 14) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + + Warmup = std::stoi(argv[12]); + Repeat = std::stoi(argv[13]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + printf("arg10 to 11: Warmup, Repeat\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + size_t total_size = + (M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) + + N * sizeof(D1DataType) + M * N * sizeof(EDataType)); + int rotate_buf_num = + ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size)); + + float ave_time = invoker.Run( + argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp new file mode 100644 index 0000000000..cbbd37408e --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int; +using F16 = ck::half_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = I8; +using B0DataType = I8; +using AccDataType = I32; +using CShuffleDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x0_f = ck::type_convert(c) * d0 * d1; + + e = x0_f; + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const ck::half_t x0_f = c * d0 * d1; + + e = x0_f; + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR + ///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; +///###### RCR + < Row, Col, DsLayout, ELayout, + A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 64, 128, 256, + 16, 16, + 32, 32, + 1, 2, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + do_verification = false; + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-25, 25}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 25}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{0, 200}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{0, 200}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + hipStream_t stream; + hip_check_error(hipStreamCreate(&stream)); + + float ave_time = invoker.Run(argument, StreamConfig{stream, time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp new file mode 100644 index 0000000000..9fe9fdde78 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -0,0 +1,486 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using B0DataType = F8; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// for gate, a_scale, b_scale +struct MulABScale +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); + } +}; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); + +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm + // clang-format off + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + 2, 2, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; + +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 256; + ck::index_t valid_tile_num = 256; + ck::index_t tokens = 16384; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 5: N, K, tokens\n"); + exit(0); + } + + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{1, 1, 1}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / (valid_tile_num / experts); + } + + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + d0_t_n, + b0_e_n_k, + d1_e_n, + c_t_k_n, + d2_e_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + const int e = expert_ids(m / MPerBlock); + for(int n = 0; n < N; ++n) + { + cde_element_op(e_t_n_host_result(t, topk_id, n), + c_t_k_n(t, topk_id, n), + d0_t_n(t, n), + d1_e_n(e, n), + d2_e_n(e, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp new file mode 100644 index 0000000000..c5328226ff --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +// using EDataType = F16; +using EDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d2); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = true; + +#if 0 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); +static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; +static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; +static constexpr ck::index_t BLOCKSIZE = 256; + +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale + // clang-format off + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; +#else +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +#if 1 + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t topk = 2; + // ck::index_t sorted_tile_num = 515; + // ck::index_t valid_tile_num = 512; + // ck::index_t tokens = 8192; + // ck::index_t sorted_tile_num = 15; + // ck::index_t valid_tile_num = 13; + ck::index_t sorted_tile_num = 259; + ck::index_t valid_tile_num = 256; + ck::index_t tokens = 4096; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7168; + ck::index_t experts = 256; + ck::index_t topk = 8; + ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 261; + ck::index_t valid_tile_num = 256; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N * 2}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k({tokens, K}); + Tensor b_e_n_k({experts, K, N * 2}); + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + // handle scale before ref. + for(int t = 0; t < tokens; ++t) + { + for(int k = 0; k < K; ++k) + { + a_t_k(t, k) = ck::type_convert(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K); + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N * 2; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm1BlockScale; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k, + b_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp new file mode 100644 index 0000000000..9e80a2ca35 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -0,0 +1,549 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using B0DataType = I4; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// for gate, a_scale, b_scale +struct MulABScale +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c); +#else + e = ck::type_convert(c); +#endif + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c); +#else + e = ck::type_convert(c); +#endif + } +}; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +static constexpr bool MulRoutedWeight = true; + +using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true + +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true + +#if 1 +void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) +{ + int KPack = 32; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex / 2] = src[(n * K + k) / 2]; + } + } +} +#endif + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< + Row, Col, DsLayout, ELayout, + A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, MPerBlock, 64, 128, + 16, 32, + 16, 16, + 8, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + ck::index_t N = 14336; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + ck::index_t tokens = 644; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 5: N, K, tokens\n"); + exit(0); + } + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{1, 1, 1}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + +#if 1 + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * experts, + K, + device_op.GetPreShuffleParameters()); +#else + // weight pre-shuffle + int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8 + int NLane = device_op.GetPreShuffleParameters(); + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int e = 0; e < experts; ++e) + { + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + b0_preshuffled(e, outputIndex % K, outputIndex / K) = b0_e_n_k(e, k, n); + } + } + } +#endif + +#if CK_USE_PK4_LAYOUT_SHUFFLE + // vector pk_i4x4 permute + for(int e = 0; e < experts; e++) + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b0_preshuffled(e, j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 6, i) = i4x2; + } + } + } + } +#endif + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + d0_t_n, + b0_e_n_k, + d1_e_n, + c_t_k_n, + d2_e_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + const int e = expert_ids(m / MPerBlock); + for(int n = 0; n < N; ++n) + { + cde_element_op(e_t_n_host_result(t, topk_id, n), + c_t_k_n(t, topk_id, n), + d0_t_n(t, n), + d1_e_n(e, n), + d2_e_n(e, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp new file mode 100644 index 0000000000..6a3986ea32 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -0,0 +1,460 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using B0DataType = F8; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr ck::index_t MPerBlock = 256; +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t MXDLPerWave = 16; +static constexpr ck::index_t NXDLPerWave = 4; +static constexpr ck::index_t NPerBlock = 256; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); + +static constexpr ck::index_t CShuffleNLane = 32; +static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 2; +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; +static constexpr bool PerTokenQuant = true; +static constexpr bool MulRoutedWeight = true; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, EDataType>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>; + // kernel 2: 128->32x128x128 + // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; + +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 133; + ck::index_t valid_tile_num = 128; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + ck::index_t tokens = 16384; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = PerTokenQuant ? std::array{1, 1, 0} + : std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + // max_token_id.mData[0] = valid_size; + // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts); + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor d0_t_n( + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk + + sizeof(B0DataType) * K * N * experts + + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm2; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + b0_e_n_k, + d0_t_n, + d1_e_n, + d2_e_n, + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp new file mode 100644 index 0000000000..354957c0d1 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using EDataType = F16; +// using EDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +// using DsLayoutGate = ck::Tuple; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d2); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; +static constexpr bool MulRoutedWeight = true; + +#if 0 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t MXDLPerWave = 2; +static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); + +static constexpr ck::index_t CShuffleNLane = 16; +static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 2; +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; + +// clang-format off + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + AK1, BK1, + MNPerXDL, MNPerXDL, + MXDLPerWave, NXDLPerWave, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + +#else +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // tokens = 1 + // topk = 1 + // experts = 8 + // per expert: + + constexpr ck::index_t valid_tile_num = + 26; // 13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock + constexpr ck::index_t sorted_tile_num = valid_tile_num + 3; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; +#if 1 + // GEMM shape + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7160; + ck::index_t experts = 256; + ck::index_t tokens = 1; + ck::index_t topk = 8; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + // int eids[] = {0, 1, 3, 3, 3}; + // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + // int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + // 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + // 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, + // 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, + // 7, 7, + // 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k(HostTensorDescriptor( + {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + for(auto i = 0; i < N * K; i++) + { + b0_e_n_k.mData[i] = ck::type_convert(static_cast(0.1)); + b0_e_n_k.mData[i + N * K] = ck::type_convert(static_cast(0.2)); + } + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk + + sizeof(B0DataType) * K * N * experts + + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k_k({tokens, topk, K}); + Tensor b_e_n_k({experts, K, N}); + Tensor c_t_n({tokens, N}); + + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + for(int k = 0; k < K; ++k) + { + a_t_k_k(t, tk, k) = ck::type_convert(a0_t_k_k(t, tk, k)) * + a1_t_k_k(t, tk, k / Scale_Block_K); + } + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm2BlockScale; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k_k, + b_e_n_k, + d2_e_n, + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp new file mode 100644 index 0000000000..3745e3d0af --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -0,0 +1,495 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using B0DataType = I4; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * 16); +#else + e = ck::type_convert(c); +#endif + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d0 * d1 * d2 * 16); +#else + e = ck::type_convert(c * d0 * d1 * d2); +#endif + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) +{ + int KPack = 32; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex / 2] = src[(n * K + k) / 2]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t CShuffleNLane = 32; +static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 2; +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm + // clang-format off + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, + AK1, BK1, + MNPerXDL, MNPerXDL, + MXDLPerWave, NXDLPerWave, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 14336; + ck::index_t experts = 8; + ck::index_t sorted_tile_num = 19; + ck::index_t valid_tile_num = 16; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + ck::index_t tokens = 512; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 3) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d2_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + d0_device_buf.ToDevice(d0_t_n.mData.data()); + d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * experts, + K, + device_op.GetPreShuffleParameters()); + +#if CK_USE_PK4_LAYOUT_SHUFFLE + // vector pk_i4x4 permute + for(int e = 0; e < experts; e++) + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b0_preshuffled(e, j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b0_preshuffled(e, j + 6, i) = i4x2; + } + } + } + } +#endif + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk + + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + b0_e_n_k, + d0_t_n, + d1_e_n, + d2_e_n, + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/66_complex_contraction_bilinear/CMakeLists.txt b/example/66_complex_contraction_bilinear/CMakeLists.txt new file mode 100644 index 0000000000..c417caf8e7 --- /dev/null +++ b/example/66_complex_contraction_bilinear/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp) +add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp) + diff --git a/example/66_complex_contraction_bilinear/README.md b/example/66_complex_contraction_bilinear/README.md new file mode 100644 index 0000000000..04d92da0d2 --- /dev/null +++ b/example/66_complex_contraction_bilinear/README.md @@ -0,0 +1,11 @@ +# Instructions for ```example_complex_contraction_bilinear_xdl_fp32``` + +## Run +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_contraction_bilinear_xdl_fp32 1 1 1 +``` + + diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp new file mode 100644 index 0000000000..480ca5a0af --- /dev/null +++ b/example/66_complex_contraction_bilinear/common_instances.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using F64 = double; + +template +using S = ck::Sequence; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Generic instances for fp32, fp16 and bf16 data types. +template +// clang-format off +using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +// Fp64 instances. +template +// clang-format off +using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp new file mode 100644 index 0000000000..619279c47b --- /dev/null +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_complex_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); } diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp new file mode 100644 index 0000000000..f3528d0901 --- /dev/null +++ b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; +using ComputeDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_complex_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); } diff --git a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc new file mode 100644 index 0000000000..82ac0a15e1 --- /dev/null +++ b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_complex_contraction_bilinear_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + // For Real Part of Complex Tensor + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + + // For Imaginary Part of Complex Tensor + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + + // Intermediate E tensor Definition + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; + std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; + std::cout << "d_ms_ns_re: " << d_ms_ns_re.mDesc << std::endl; + std::cout << "e_ms_ns_re: " << e_ms_ns_host_result_re.mDesc << std::endl; + + std::cout << "a_ms_ks_img: " << a_ms_ks_img.mDesc << std::endl; + std::cout << "b_ns_ks_img: " << b_ns_ks_img.mDesc << std::endl; + std::cout << "d_ms_ns_img: " << d_ms_ns_img.mDesc << std::endl; + std::cout << "e_ms_ns_img: " << e_ms_ns_host_result_img.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + + a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + + default: + a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + + break; + } + + DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_re(sizeof(EDataType) * + e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); + + DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_img(sizeof(EDataType) * + e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); + + // Intermediate Value For E Real and Img + DeviceMem e_device_buf_re1(sizeof(EDataType) * + e_ms_ns_device_result_re.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf_img1(sizeof(EDataType) * + e_ms_ns_device_result_img.mDesc.GetElementSpaceSize()); + + a_device_buf_re.ToDevice(a_ms_ks_re.mData.data()); + b_device_buf_re.ToDevice(b_ns_ks_re.mData.data()); + d_device_buf_re.ToDevice(d_ms_ns_re.mData.data()); + + a_device_buf_img.ToDevice(a_ms_ks_img.mData.data()); + b_device_buf_img.ToDevice(b_ns_ks_img.mData.data()); + d_device_buf_img.ToDevice(d_ms_ns_img.mData.data()); + + // set zero + e_device_buf_re.SetZero(); + e_device_buf_img.SetZero(); + + // set zero for intermediate values + e_device_buf_re1.SetZero(); + e_device_buf_img1.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + // For real Intermediate Value re_1 + + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument_re1 = + op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), + b_device_buf_re.GetDeviceBuffer(), + std::array{d_device_buf_re.GetDeviceBuffer()}, + e_device_buf_re1.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_re1)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel}); + + alpha = -1.f; + beta = 1.f; + + a_element_op = AElementOp{}; + b_element_op = BElementOp{}; + cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + // For real Intermediate Value re_2 + // auto op = DeviceOpInstance{}; + // auto invoker = op.MakeInvoker(); + auto argument_re2 = + op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), + b_device_buf_img.GetDeviceBuffer(), + std::array{e_device_buf_re1.GetDeviceBuffer()}, + e_device_buf_re.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_re2)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel}); + + alpha = 1.f; + beta = 1.f; + + a_element_op = AElementOp{}; + b_element_op = BElementOp{}; + cde_element_op = CDEElementOp{alpha, beta}; + + auto argument_img1 = + op.MakeArgument(a_device_buf_re.GetDeviceBuffer(), + b_device_buf_img.GetDeviceBuffer(), + std::array{d_device_buf_img.GetDeviceBuffer()}, + e_device_buf_img1.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_img1)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel}); + + alpha = 1.f; + beta = 1.f; + + auto argument_img2 = + op.MakeArgument(a_device_buf_img.GetDeviceBuffer(), + b_device_buf_re.GetDeviceBuffer(), + std::array{e_device_buf_img1.GetDeviceBuffer()}, + e_device_buf_img.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument_img2)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K * 2; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2; + + float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data()); + e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data()); + + auto isRealOk = 0; + auto isImgOk = 0; + + if(do_verification) + { + // Real Part Verification + Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument_re = ref_op.MakeArgument( + a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_re); + + alpha = 1.f; + beta = 1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1), + c_ms_ns_host_result_re(m0, m1, n0, n1), + d_ms_ns_re(m0, m1, n0, n1)); + } + } + } + } + + alpha = 1.f; + beta = -1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + auto ref_argument_re1 = ref_op.MakeArgument( + a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_re1); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1), + e_ms_ns_host_result_re(m0, m1, n0, n1), + c_ms_ns_host_result_re1(m0, m1, n0, n1)); + } + } + } + } + + isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; + + // Img Part Verification + Tensor c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + + auto ref_argument_img = ref_op.MakeArgument( + a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_img); + + alpha = 1.f; + beta = 1.f; + + cde_element_op = CDEElementOp{alpha, beta}; + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1), + c_ms_ns_host_result_img(m0, m1, n0, n1), + d_ms_ns_img(m0, m1, n0, n1)); + } + } + } + } + + auto ref_argument_img1 = ref_op.MakeArgument( + a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument_img1); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1), + e_ms_ns_host_result_img(m0, m1, n0, n1), + c_ms_ns_host_result_img1(m0, m1, n0, n1)); + } + } + } + } + + isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1; + + return (isRealOk && isImgOk); + } + + return 0; +} diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt new file mode 100644 index 0000000000..14b648c9f8 --- /dev/null +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -0,0 +1,68 @@ +add_custom_target(example_gemm_mx) + +add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) + +add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) + +# TODO: Fix RRR +# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) +# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) + +add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp6) + +add_example_executable(example_gemm_mx_bf6 gemm_mx_bf6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf6) + +add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) + +add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle) + +set(FP4_MXGEMM_OPTIONS) +list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1") +example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B no-shuffling + scale shuffling +example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B no-shuffling + scale shuffling (async loads) +example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B shuffling + scale shuffling (async loads) +example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + +set(FP8_MXGEMM_OPTIONS) +list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32") +example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) + +set(FP6_MXGEMM_OPTIONS) +list(APPEND FP6_MXGEMM_OPTIONS -mavx512f) +example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md new file mode 100644 index 0000000000..007c934b7e --- /dev/null +++ b/example/67_gemm_microscaling/README.md @@ -0,0 +1,27 @@ +# GEMM Examples for Microscaling Formats + +## example_gemm_mx_fp8 + +Custom verification parameters: +```bash +# arg1: verification (0=no, 1=CPU) +# arg2: initialization (0=constant values, 1=integer values, 2=decimal values) +# arg3: time kernel (0=no, 1=yes) +# arg4: verbosity (0=no info, 1=verbose info) +# arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC +# arg11: KBatch +# arg12: warmup runs pre-timing +# arg13: repeat run count for timing +./bin/example_gemm_mx_fp8 1 1 0 1 +``` + +Custom tensor shapes: +```bash +./bin/example_gemm_mx_fp8 1 2 1 0 256 256 512 -1 -1 -1 1 10 10 +``` + +Default invocation: +```bash +# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 +./bin/example_gemm_mx_fp8 +``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_bf6.cpp b/example/67_gemm_microscaling/gemm_mx_bf6.cpp new file mode 100644 index 0000000000..34810c2961 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf6.cpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf6x16_pk_t; +using BDataType = ck::bf6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 16; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 1, // AK1 + 1, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_bf8.cpp new file mode 100644 index 0000000000..58f2dcb010 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf8.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 128, // BlockSize: Thread block size + 128, // MPerBlock + 32, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp new file mode 100644 index 0000000000..2d0585c880 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using MFMA = ck::tensor_layout::gemm::MFMA; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ck::type_convert; + +struct ExecutionConfig final +{ + int do_verification = 1; // (0=no, 1=CPU) + int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) + bool time_kernel = false; // (0=no, 1=yes) + int verbosity = 0; // (0=no info, 1=verbose info) + int warm_up = 10; + int repeat = 10; +}; + +struct ProblemSizeSplitK final +{ + + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = -1; + ck::index_t StrideB = -1; + ck::index_t StrideC = -1; + + ck::index_t KBatch = 1; +}; + +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + } + else if(argc >= 11) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.verbosity = std::stoi(argv[4]); + + problem_size.M = std::stoi(argv[5]); + problem_size.N = std::stoi(argv[6]); + problem_size.K = std::stoi(argv[7]); + + problem_size.StrideA = std::stoi(argv[8]); + problem_size.StrideB = std::stoi(argv[9]); + problem_size.StrideC = std::stoi(argv[10]); + + if(argc >= 12) + { + problem_size.KBatch = std::stoi(argv[11]); + config.warm_up = std::stoi(argv[12]); + config.repeat = std::stoi(argv[13]); + } + } + else + { + std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl + << "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl + << "arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC" << std::endl + << "arg11: KBatch" << std::endl + << "arg12: warmup runs pre-timing" << std::endl + << "arg13: repeat run count for timing" << std::endl; + + return false; + } + + return true; +} + +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, + // 2-k))); + + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +template +bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) +{ + constexpr bool BPreShuffle = ck::is_same_v; + using BRefLayout = ck::conditional_t; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if constexpr(std::is_same_v) + return HostTensorDescriptor({row, col}, {stride, 1}); + else + return HostTensorDescriptor({row, col}, {1, stride}); + }; + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + return static_cast(col); + else + return static_cast(row); + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + if(K % ck::packed_size_v != 0 || K % ck::packed_size_v != 0) + { + throw std::runtime_error("wrong! K must be multiple of packed size."); + }; + + // Hardcode scale layouts as per pipeline assumptions + // TODO: Allow user to specify scale layouts + using AScaleLayout = Row; + using BScaleLayout = Col; + + auto Scale_Padded_M = ck::math::integer_least_multiple(M, ScaleBlockSize); + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + auto b_k_n = + std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); + auto b_input = b_k_n; + if constexpr(BPreShuffle) + b_input = std::make_shared>( + f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size + + // scales for A and B + Tensor a_m_k_scale(f_host_tensor_descriptor( + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + // shuffled scales for A and B + Tensor a_shuffled_scale(f_host_tensor_descriptor( + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_shuffled_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification + Tensor c_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host + + if(config.verbosity >= 0) + { + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n->mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; + } + + auto a_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); + else + return ck::type_convert(x); + }; + auto b_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); + else + return ck::type_convert(x); + }; + + using int_distr = std::uniform_int_distribution; + using float_distr = std::uniform_real_distribution; + switch(config.init_method) + { + case 0: // Initializations for development and debugging + + ck::utils::FillConstant{a_data_element(0.5f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + + ck::utils::FillConstant{b_data_element(2.0f)}(*b_k_n); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); + + if(config.verbosity > 0) + { + std::cout << "Init A = {0.5}" << std::endl; + std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {2.0}" << std::endl; + std::cout << "Init B scale = {0.5}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + case 1: + a_m_k.GenerateTensorDistr( + int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5] + b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5] + static_assert(ck::is_same_v); + a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} + break; + + case 2: + a_m_k.GenerateTensorDistr( + float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2] + a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); + + b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); + b_k_n_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); + break; + + default: + if(config.verbosity > 0) + { + std::cout << "NOTE: No input data initialization." << std::endl; + } + } + + preShuffleScaleBuffer>(a_m_k_scale.mData.data(), + a_shuffled_scale.mData.data(), + Scale_Padded_M, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + if constexpr(BPreShuffle) + { + int NPerXdl = 16; // Fixed 16 + preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl); + } + + if(config.verbosity > 0) + std::cout << "Device memory allocation..." << std::endl; + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize()); + + if(config.verbosity > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); + b_device_buf.ToDevice(b_input->mData.data()); + b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); + + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // run GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong!\n" + "Provided combination of compilation and runtime parameters is " + "not consistent with the supported device_gemm arguments."); + } + + std::size_t total_size = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() + + a_shuffled_scale.GetElementSpaceSizeInBytes() + + b_shuffled_scale.GetElementSpaceSizeInBytes(); + const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size); + const int rotating_count = std::max(1, std::min(config.repeat, static_cast(total_cnt))); + if(config.verbosity > 0) + { + std::cout << "Computing GEMM on device..." << std::endl << std::endl; + } + + float ave_time = invoker.Run(argument, + StreamConfig{nullptr, + config.time_kernel, + config.verbosity, + config.warm_up, + config.repeat, + rotating_count > 1, + rotating_count}); + + bool res_verified = true; + if(config.do_verification > 0) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Computing GEMM on host..." << std::endl; + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + *b_k_n, + b_k_n_scale, + c_m_n_host_result, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + if(config.verbosity > 0) + { + std::cout << "Done." << std::endl; + std::cout << "Comparing results..." << std::endl; + } + + res_verified = + res_verified && + ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1); + + if(config.verbosity > 0 && res_verified) + std::cout << "Verification Successful!" << std::endl; + } + else + { + if(config.verbosity > 0) + std::cout << "Done." << std::endl; + } + + if(config.time_kernel) + { + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of + // partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + std::size_t num_btype = + sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = static_cast(num_btype) / 1e6f / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return res_verified; +} + +template +bool run_mx_gemm_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && + run_mx_gemm(problem_size, config); +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp4.cpp b/example/67_gemm_microscaling/gemm_mx_fp4.cpp new file mode 100644 index 0000000000..65fbe3491a --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp4.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f4x2_pk_t; +using BDataType = ck::f4x2_pk_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +// AB DataType: f4x2_pk_t +// Mathmatically, all numbers are represented as f4x2. +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 256, // MPerBlock + 256, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 8, // NXdlPerWave + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..6e1efd266b --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f4x2_pk_t; +using BDataType = ck::f4x2_pk_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = MFMA; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +// AB DataType: f4x2_pk_t +// Mathmatically, all numbers are represented as f4x2. +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 512, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 8, // NXdlPerWave + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 4, // CShuffleNXdlPerWavePerShuffle + S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlockW + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp6.cpp b/example/67_gemm_microscaling/gemm_mx_fp6.cpp new file mode 100644 index 0000000000..615980082d --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp6.cpp @@ -0,0 +1,99 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f6x16_pk_t; +using BDataType = ck::f6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v; // K dimension size per block + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Number of threads per block + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 1, // AK1 number of elements to read at a time when transferring from global memory to LDS + 1, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..e6fe791178 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp new file mode 100644 index 0000000000..fdc4ace471 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + 256, // KPerBlock + 16, // AK1 + 8, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<32, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp new file mode 100644 index 0000000000..aaf0cb3891 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t BlockSize = 256; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, BlockSize, + MPerBlock, NPerBlock, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{0.5f}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.5f}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp new file mode 100644 index 0000000000..24ab326391 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -0,0 +1,545 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t BlockSize = 256; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, BlockSize, + MPerBlock, NPerBlock, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{0.5f}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.5f}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..08ed8e11fb --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,574 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// B preshuffle +void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + I64 tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K_pk; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 64, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * 2 * experts, + K, + device_op.GetPreShuffleParameters()); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp new file mode 100644 index 0000000000..1b8a7a16e3 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + // d2_e_n.savetxt("weight.txt", "int"); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp new file mode 100644 index 0000000000..829bf9af24 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -0,0 +1,526 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..efbd0f0c03 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,584 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// B preshuffle +void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + I64 tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K_pk; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 8, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = 13; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + // A, B Scale preshuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * experts, + K, + device_op.GetPreShuffleParameters()); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index f9e62a2356..56d709f41b 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,14 @@ include_directories(BEFORE add_custom_target(examples) + +# list of examples that are labelled as REGRESSION_EXAMPLE for make regression (runtime more than 30 seconds) +# all other tests are labelled as SMOKE_EXAMPLE +set(REGRESSION_EXAMPLES + example_sparse_embedding3_forward_layernorm +) + + function(add_example_dependencies EXAMPLE_NAME FILE_NAME) if(FILE_NAME) add_dependencies(EXAMPLE_NAME FILE_NAME) @@ -12,86 +20,110 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME) endfunction(add_example_dependencies EXAMPLE_NAME) function(add_example_executable EXAMPLE_NAME FILE_NAME) - message("adding example ${EXAMPLE_NAME}") + message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message("removing example source file ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message(DEBUG "removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() - if(INSTANCES_ONLY) - set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(EX_TARGETS ${GPU_TARGETS}) - endif() + set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") - message("removing dl example ${source} ") + message(DEBUG "removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any DPP examples if DPP_KERNELS not set + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + message(DEBUG "removing dpp example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") - message("removing xdl example ${source} ") + message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma example ${source} ") + message(DEBUG "removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any microscaling examples if gfx950 target is not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message(DEBUG "removing microscaling example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any FP8 examples if CK_ENABLE_FP8 not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") - message("removing fp8 example ${source} ") + message(DEBUG "removing fp8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any BF8 examples if CK_ENABLE_BF8 not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8") - message("removing bf8 example ${source} ") + message(DEBUG "removing bf8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 + foreach(source IN LISTS FILE_NAME) + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)") + message(DEBUG "Skipping ${source} example for current target") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) + elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 + message(DEBUG "trimming targets for ${FILE_NAME}") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -103,7 +135,14 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() - #message("add_example returns ${result}") + message(DEBUG "add_example returns ${result}") + if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) + set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST") + add_dependencies(smoke ${EXAMPLE_NAME}) + elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) + set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST") + add_dependencies(regression ${EXAMPLE_NAME}) + endif() set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) @@ -114,7 +153,7 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME) endfunction(add_example_dependencies EXAMPLE_NAME) function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) - message("adding example ${EXAMPLE_NAME}") + message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) @@ -141,44 +180,41 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) set(test 1) endif() if(test EQUAL 1) - message("removing example ${source} ") + message(DEBUG "removing example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() endif() - if(INSTANCES_ONLY) - set(EX_TARGETS ${DEFAULT_GPU_TARGETS}) - else() - set(EX_TARGETS ${GPU_TARGETS}) - endif() + set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") - message("removing dl example ${source} ") + message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any XDL examples if gfx9 targets are not on the list foreach(source IN LISTS FILE_NAME) if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") - message("removing xdl example ${source} ") + message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #Do not build any WMMA examples if gfx11 targets are not on the list foreach(source IN LISTS FILE_NAME) if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma example ${source} ") + message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -188,10 +224,18 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() - #message("add_example returns ${result}") + + message(DEBUG "add_example returns ${result}") set(result ${result} PARENT_SCOPE) + endfunction(add_example_executable_no_testing EXAMPLE_NAME) +function(example_compile_options EXAMPLE_NAME) + if(TARGET ${EXAMPLE_NAME}) + target_compile_options(${EXAMPLE_NAME} ${ARGN}) + endif() +endfunction(example_compile_options) + # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) diff --git a/example/README.md b/example/README.md new file mode 100644 index 0000000000..43b3419f80 --- /dev/null +++ b/example/README.md @@ -0,0 +1,2 @@ +[Back to the main page](../README.md) +# Composable Kernel examples \ No newline at end of file diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index f82207b920..1b004ec100 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,7 +1,7 @@ # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING - "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") + "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(FMHA_FWD_ENABLE_APIS STREQUAL "all") set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) endif() @@ -17,17 +17,44 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) list(APPEND FMHA_FWD_ENABLE_APIS "fwd") endif() +file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +) +# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change +set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") + string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") -# generate a list of kernels, but not actually emit files at config sta -execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt +set(FMHA_FWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} + # --filter fmha_fwd... +) +set(FMHA_BWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd + --receipt 3 + # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) +# generate a list of kernels, but not actually emit files at config sta execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3 + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + RESULT_VARIABLE ret ) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") +endif() + +execute_process( + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") +endif() # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory # as current cmake list, otherwise will not figure out the dependency properly @@ -36,20 +63,22 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} ) add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3 + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding example ${EXAMPLE_FMHA_FWD}") +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) @@ -57,7 +86,7 @@ target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding example ${EXAMPLE_FMHA_BWD}") +message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) @@ -65,7 +94,7 @@ target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) - set(FMHA_FWD_FAST_EXP2 true) + set(FMHA_FWD_FAST_EXP2 true) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) @@ -74,9 +103,9 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... because they are auto-generated if(FMHA_FWD_FAST_EXP2) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) @@ -94,6 +123,18 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally enable call to the pagedkv_prefill API in fmha_fwd example +if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +endif() + +# conditionally specify the use of OCP_FP8 +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + # Allow comparing floating points directly in order to check sentinel values list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 0bb5408772..72109a660b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -6,7 +6,8 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ make tile_example_fmha_fwd -j ``` This will result in an executable `build/bin/tile_example_fmha_fwd` @@ -14,8 +15,7 @@ This will result in an executable `build/bin/tile_example_fmha_fwd` ## kernel The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template. -There are 3 template parameters for this kernel template. -* `TilePartitioner` is used to map the workgroup to corresponding tile, `fmha_fwd_tile_partitioner.hpp` in this folder served as this purpose. +There are 2 template parameters for this kernel template. * `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)). * `EpiloguePipeline` will modify and store out the result in the last phase. People usually will do lot of post-fusion at this stage, so we also abstract this concept. Currently we didn't do much thing at the epilogue stage but leave the room for future possible support. @@ -23,7 +23,7 @@ There are 3 template parameters for this kernel template. To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable. ## executable -`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all supported args. Below is an example of the output (may subject to change) +`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change) ``` args: -v weather do CPU validation or not (default:1) @@ -31,47 +31,53 @@ args: -b batch size (default:2) -h num of head, for q (default:8) -h_k num of head, for k/v, -1 means equal to h (default:-1) - if not equal to h, then this is GQA/MQA case + if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) - total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary - also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) - -s_k seqlen_k, -1 means equal to s (default:-1) + total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary + also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) + -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) - note when squant=1, this value will be modified by range_q/k + note when squant=1, this value will be modified by range_q/k -range_q per-tensor quantization range of q. used if squant=1. (default:16) -range_k per-tensor quantization range of k. used if squant=1. (default:16) -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) -squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto) - 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. - calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o + 0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O. + calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) - if true, will be b*h*s*d, else b*s*h*d + if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) -bias n or 0, no bias (default:n) - e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s - a(libi) or 2, alibi with 1*h. a:1, b*h + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h -prec data type. fp16/bf16/fp8/bf8 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) - 't', top-left causal mask, 'b', bottom-r causal mask - 't:l,r', top-left sliding window attn(swa) with FA style left right size - 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size - 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa - 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa - 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) + 't', top-left causal mask, 'b', bottom-r causal mask + 't:l,r', top-left sliding window attn(swa) with FA style left right size + 'b:l,r', bottom-r sliding window attn(swa) with FA style left right size + 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa + 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa + 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will print kernel name (default:0) -init init method. ui, uniform random int, ni, normalized random int (default:uf) - uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization + uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization -seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939) + -drop_seed seed for random number generator (default:1) +-drop_offset offset for random number generator (default:0) + -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) ``` -Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. +Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with + batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. @@ -121,6 +127,6 @@ Note FA use bottom-right by default to express swa case, here we require you exp TBD ## FP8 experimental support -As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx940/941/942 machine and ROCm 6.0+. +As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 66691356ab..9e15a822ef 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -2,10 +2,17 @@ # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation -DTYPE_MAP = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp8" : "ck_tile::fp8_t" +FWD_DTYPE_MAP = { + "fp16" : "FmhaFwdFp16", + "bf16" : "FmhaFwdBf16", + "fp8" : "FmhaFwdFp8", + "fp8fp16": "FmhaFwdFp8Fp16", + "fp8bf16": "FmhaFwdFp8Bf16" +} + +BWD_DTYPE_MAP = { + "fp16": "FmhaBwdFp16", + "bf16": "FmhaBwdBf16" } MASK_IMPL = { @@ -107,11 +114,15 @@ LAYOUT_MAP = { PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", } PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py new file mode 100644 index 0000000000..ffb6d579ed --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -0,0 +1,625 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_BATCH_PREFILL_PIPELINE_MAP = { + "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" +FMHA_FWD_API=""" +#include + +namespace {{ +bool get_num_cus(unsigned& num_cu) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cu = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_batch_prefill_(s, a); + }} +""" + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + constraint : CppConstraint + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + +class KernelComponentFactory: + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + @staticmethod + def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + assert False + return pipelines + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if 128 in result.keys(): + result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = CustomFactory.get_hdim_tile_size_dict(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 096394c0c9..89fbcff40c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -60,6 +60,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + false, {F_bias}, {F_dbias}, false, @@ -168,7 +169,7 @@ template 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, @@ -176,7 +177,8 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) ); }} -float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ +template <> +float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ float r = -1; {F_dispatch} return r; @@ -283,7 +285,7 @@ class FmhaBwdApiPool: inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_deterministic=BOOL_MAP[trait.deterministic]) @@ -360,7 +362,7 @@ class FmhaBwdDQDKDVKernel: FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = BWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_tile.F_bm0, F_bn0 = self.F_tile.F_bn0, F_bk0 = self.F_tile.F_bk0, @@ -412,14 +414,26 @@ class FmhaBwdDQDKDVKernel: pn = pad_name() n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_dbias == 't' : n += '_dbias' + else: n += '_ndbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + if self.F_dropout != 'no' : n += f'_{self.F_dropout}' + else: n += '_ndropout' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' return n @property @@ -469,7 +483,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> gen = list() api_pool = FmhaBwdApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue @@ -489,9 +503,10 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] @@ -499,13 +514,45 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue - if receipt == 3: + elif receipt == 3: cond = dtype in ['fp16', 'bf16'] cond &= bias in ['no', 'alibi'] cond &= dpad == dvpad cond &= deterministic == "f" if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= deterministic == "f" + if not cond: + continue + # Aiter (mha_bwd) integration + elif receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= dpad == dvpad + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -585,7 +632,7 @@ class FmhaBwdOGradDotOKernel: FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = BWD_DTYPE_MAP[self.F_dtype], F_spad = BOOL_MAP[self.F_spad], F_dvpad = BOOL_MAP[self.F_dvpad], F_mode = MODE_MAP[self.F_mode], @@ -602,13 +649,14 @@ class FmhaBwdOGradDotOKernel: pn = pad_name() n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' + else: n += '_npad' return n @property def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: +def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy # support this in future def get_occupancy(dtype, hdim): @@ -616,7 +664,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: gen = list() - for dtype in DTYPE_MAP.keys(): + for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue @@ -627,6 +675,26 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_spad=spad, F_dvpad=dvpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim)) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + # Aiter (mha_bwd) integration + if receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -716,7 +784,7 @@ class FmhaBwdConvertQGradKernel: FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = BWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_bm0, F_bn0 = self.F_bn0, F_spad = BOOL_MAP[self.F_spad], @@ -736,14 +804,16 @@ class FmhaBwdConvertQGradKernel: pn = pad_name() n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}" if pn != '' : n += f'_{pn}' - if self.F_deterministic == 't' : n += f'_deterministic' + else: n += '_npad' + if self.F_deterministic == 't' : n += '_deterministic' + else: n += '_ndeterministic' return n @property def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: +def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: # TODO: we don't support tuning yet, so pick up one value for pad/occupancy # support this in future def get_occupancy(dtype, hdim): @@ -751,7 +821,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: gen = list() - for dtype in DTYPE_MAP.keys(): + for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) if d == None: continue @@ -762,6 +832,26 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: continue k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + # Aiter (mha_bwd) integration + if receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + if not cond: + continue + # Aiter (mha_varlen_bwd) integration + elif receipt == 400: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -778,27 +868,37 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - kernels = get_bwd_dot_do_o_blobs() +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] + + kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs() + kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) for kernel in kernels: write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) for kernel in kernels: write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] + with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs() + kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs() + kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 860ee20d3e..06a012d277 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -3,9 +3,10 @@ # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass +from dataclasses import dataclass, field import fnmatch import itertools +import os from pathlib import Path from typing import List, Optional, Tuple @@ -21,41 +22,48 @@ DTYPE_BITS = { "bf8" : 8 } -TILE_PARTITIONER_MAP = { - "shb" : "ck_tile::FmhaFwdTilePartitioner_SHB", - "hbs" : "ck_tile::FmhaFwdTilePartitioner_HBS", +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 } FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd.hpp" """ FMHA_FWD_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; -using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; -using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, false, {F_lse}, {F_dropout}, {F_squant}, - {F_occupancy}>; + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< @@ -72,6 +80,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait_{F_idx}>; @@ -84,12 +93,10 @@ using fmha_epilogue_{F_idx} = {F_spad}, {F_dvpad}>>; using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdKernel<{F_tile_partitioner}, - fmha_pipeline_{F_idx}, - fmha_epilogue_{F_idx}>; + ck_tile::FmhaFwdKernel; -using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; #include @@ -108,8 +115,52 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" FMHA_FWD_API=""" +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + {F_dispatch} return r; }} @@ -119,46 +170,62 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ {F_inner_dispatch} }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; return fmha_fwd_(s, a); }} """ +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + @dataclass class FmhaFwdApiTrait: pipeline_tag : str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0blen : int - vlayout : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + constraint : CppConstraint @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' @property def scheck(self) -> str: @@ -166,7 +233,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -177,7 +244,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' else: assert False @@ -188,9 +255,10 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' + elif self.pipeline_tag in ['qr', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' else: assert False @property @@ -199,25 +267,29 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' + elif self.pipeline_tag in ['qr', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' else: assert False @dataclass class FmhaFwdPipeline: tag : str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: @@ -232,14 +304,33 @@ class FmhaFwdPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n class FmhaFwdApiPool: @@ -251,31 +342,33 @@ class FmhaFwdApiPool: # TODO: do we need to check duplication? if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.pool[trait.dtype][hdim].append(copy.copy(trait)) @property def api(self) -> str: per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -285,24 +378,33 @@ class FmhaFwdApiPool: @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm : int # number of warps along q seqlen (block warps) - F_rn : int # number of warps along k seqlen(not used) - F_rk : int # number of warps along gemm-k(not used) - F_wm : int # warp size along m (warp size) - F_wn : int # warp size along n - F_wk : int # warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ - f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ - ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass class FmhaFwdKernel: @@ -314,12 +416,6 @@ class FmhaFwdKernel: F_pipeline : FmhaFwdPipeline mask_impl : str - def get_tp(self) -> str: - if self.F_mode == 'group': - return 'hbs' - else: - return 'shb' - @property def template(self) -> str: kernel_body = str() @@ -327,39 +423,46 @@ class FmhaFwdKernel: FMHA_FWD_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_tile.F_bm0, F_bn0 = self.F_tile.F_bn0, F_bk0 = self.F_tile.F_bk0, F_bn1 = self.F_tile.F_bn1, F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], - F_tile_partitioner = TILE_PARTITIONER_MAP[self.get_tp()]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) @property def name(self) -> str: # TODO: we don't encode idx here - return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_{self.get_tp()}_" + \ + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ self.F_tile.name + '_' + self.F_pipeline.name @property @@ -377,9 +480,10 @@ class FmhaFwdKernel: bk0=self.F_tile.F_bk0, bn1=self.F_tile.F_bn1, bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, + bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, @@ -387,31 +491,39 @@ class FmhaFwdKernel: spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) - } - else: - return None +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + ### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } + else: + return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! @@ -419,53 +531,76 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if hdim == 256 and hdim_v == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None else: assert False return pipelines +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if (128, 128) in result.keys(): + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) + factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + + for dtype in FWD_DTYPE_MAP.keys(): + d = factory.get_hdim_tile_size_dict(dtype) if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + if (hdim, hdim_v) == (192, 128) or hdim == 160: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -473,16 +608,56 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue - if receipt == 2: + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) @@ -494,15 +669,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index cfd1d01c91..517e84f380 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -46,9 +46,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline< fmha_pipeline_problem_{F_idx}>; -using fmha_kernel_{F_idx} = - ck_tile::FmhaFwdAppendKVKernel, - fmha_pipeline_{F_idx}>; +using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel; using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>; @@ -181,9 +179,9 @@ class FmhaFwdAppendKVApiPool: inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) @@ -216,7 +214,7 @@ class FmhaFwdAppendKVKernel: FMHA_FWD_APPENDKV_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bs = self.F_tile.F_bs, F_bsk = self.F_tile.F_bsk, F_bd = self.F_tile.F_bd, @@ -301,6 +299,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> elif dtype in ['fp8', 'bf8']: # rope/paged-kv is not supported pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None else: assert False return pipelines @@ -308,7 +309,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> gen = list() api_pool = FmhaFwdAppendKVApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) if d == None: continue @@ -322,14 +323,21 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + # 2 - Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16, bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -341,15 +349,17 @@ def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> Non def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] with file_path.open('a') as f: _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index ba826c8fb3..edc1532a05 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -29,39 +29,50 @@ DTYPE_BITS = { "bf8" : 8 } +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + # 160: 160, + 256: 256 +} + FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", } FMHA_FWD_SPLITKV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; namespace {{ -template -struct kernel_runner {{ -using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; -using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; -using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; +template +struct instance {{ +using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; using fmha_shape = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, {F_vlayout}>; using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, - false, + /*kHasBiasGrad=*/false, {F_lse}, {F_squant}, {F_pagedkv}, kHasUnevenSplits, + kMergeNumHeadGroupsSeqLenQ, {F_occupancy}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< @@ -77,21 +88,22 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::OaccDataType, fmha_shape, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait>; using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, - {F_spad}, {F_dvpad}>>; + false, false>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -104,34 +116,56 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }}; }} -using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, +using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wtautological-compare" + +namespace {{ +template +void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ + if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS + && (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask> + || std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{ + if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{ + instance::run(s, a); + }} else {{ + instance::run(s, a); + }} + }} else {{ + instance::run(s, a); + }} +}} +}} // anonymous namespace + +#pragma clang diagnostic pop + template<> void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ if constexpr({F_mode} == false) {{ // batch mode // we don't check every seqlen_k values for kvcache if (a.seqlen_k_ptr != nullptr) {{ - kernel_runner::run(s, a); + run_instance(s, a); // make sure F_bn0 is divisible by F_bk1 }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - kernel_runner::run(s, a); + run_instance(s, a); }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} else {{ - kernel_runner::run(s, a); + run_instance(s, a); }} }} template<> std::string fmha_fwd_splitkv_get_name_() {{ - using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -141,7 +175,7 @@ using fmha_dtype_{F_idx} = {F_dtype}; namespace {{ template -struct kernel_runner {{ +struct instance {{ using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad}, {F_dvpad}, {F_lse}, @@ -154,23 +188,22 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< typename FmhaFwdTypeConfig::OaccDataType, typename FmhaFwdTypeConfig::ODataType, {F_hdim}, - {F_bm0}, - {F_bn1}, {F_mode}, + {F_bn1}, fmha_trait>; using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, - {F_spad}, {F_dvpad}>>; + false, false>>; using fmha_kernel = - ck_tile::FmhaFwdSplitKVCombineKernel, - fmha_pipeline, - fmha_epilogue>; + ck_tile::FmhaFwdSplitKVCombineKernel; static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ @@ -183,7 +216,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }}; }} -using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn1}, +using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; #include @@ -191,21 +224,23 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m template<> void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if (a.num_splits <= 16) {{ - kernel_runner<4>::run(s, a); + if (a.num_splits <= 8) {{ + instance<3>::run(s, a); + }} else if (a.num_splits <= 16) {{ + instance<4>::run(s, a); }} else if (a.num_splits <= 32) {{ - kernel_runner<5>::run(s, a); + instance<5>::run(s, a); }} else if (a.num_splits <= 64) {{ - kernel_runner<6>::run(s, a); + instance<6>::run(s, a); }} else if (a.num_splits <= 128) {{ - kernel_runner<7>::run(s, a); + instance<7>::run(s, a); }} }} template<> std::string fmha_fwd_splitkv_combine_get_name_() {{ - using k_ = kernel_runner<6>::fmha_kernel; /// FIXME: choose real kernel type + using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -220,11 +255,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a if(s.log_level_ > 0) std::cout << ", " << fmha_fwd_splitkv_get_name_() - << ", " << fmha_fwd_splitkv_combine_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} ); }} @@ -236,12 +271,31 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_splitkv_(s, a); + // get combine kernel tile sizes + using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; + constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; + + // make sure we can reuse the padding flags in combine kernels + static_assert({F_bm0} % kM0 == 0); + static_assert({F_bn1} % 32 == 0); + + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + return -1; + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} }} """ @@ -257,9 +311,10 @@ class FmhaFwdSplitKVApiTrait: bk0 : int # tile size along qk gemm unroll bn1 : int # tile size along v head_dim bk1 : int # tile size along kv gemm unroll - bk0blen : int + bk0max : int vlayout : str mask : str + logits : str bias : str # lse : str # squant : str # @@ -267,12 +322,12 @@ class FmhaFwdSplitKVApiTrait: skpad : str dpad : str dvpad : str - pagedkv : str + pagedkv : str @property def name(self) -> str: - return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.dvpad}-{self.pagedkv}' @property @@ -281,7 +336,7 @@ class FmhaFwdSplitKVApiTrait: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -292,7 +347,7 @@ class FmhaFwdSplitKVApiTrait: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' else: assert False @@ -303,9 +358,10 @@ class FmhaFwdSplitKVApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: - if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_q % {self.bk0blen} == 0' + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' else: assert False @property @@ -314,9 +370,10 @@ class FmhaFwdSplitKVApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: - if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.hdim_v % {self.bk0blen} == 0' + elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' else: assert False @dataclass @@ -328,6 +385,7 @@ class FmhaFwdSplitKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_squant : str # @@ -347,14 +405,29 @@ class FmhaFwdSplitKVPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' else: if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' return n @dataclass @@ -377,8 +450,13 @@ class FmhaFwdSplitKVCombinePipeline: pn = pad_name() n = f'{self.tag}' if pn != '' : n += f'_{pn}' + else: n += '_npad' + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' return n class FmhaFwdSplitKVApiPool: @@ -406,15 +484,15 @@ class FmhaFwdSplitKVApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, - F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -424,12 +502,11 @@ class FmhaFwdSplitKVApiPool: @dataclass class FmhaFwdSplitKVCombineTileSize: - F_bm0 : int # tile size along q seqlen F_bn1 : int # tile size along v head_dim F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy @property def name(self) -> str: - return f"b{self.F_bm0}x{self.F_bn1}" +\ + return f"b{self.F_bn1}" +\ ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") @dataclass @@ -449,24 +526,31 @@ class FmhaFwdSplitKVKernel: FMHA_FWD_SPLITKV_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bm0 = self.F_tile.F_bm0, F_bn0 = self.F_tile.F_bn0, F_bk0 = self.F_tile.F_bk0, F_bn1 = self.F_tile.F_bn1, F_bk1 = self.F_tile.F_bk1, - F_bk0blen = self.F_tile.F_bk0blen, - F_rm = self.F_tile.F_rm, - F_rn = self.F_tile.F_rn, - F_rk = self.F_tile.F_rk, - F_wm = self.F_tile.F_wm, - F_wn = self.F_tile.F_wn, - F_wk = self.F_tile.F_wk, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], - F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -498,8 +582,9 @@ class FmhaFwdSplitKVKernel: bk0=self.F_tile.F_bk0, bn1=self.F_tile.F_bn1, bk1=self.F_tile.F_bk1, - bk0blen=self.F_tile.F_bk0blen, + bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, @@ -526,8 +611,7 @@ class FmhaFwdSplitKVCombineKernel: FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( F_idx = self.F_idx, F_hdim = self.F_hdim, - F_dtype = DTYPE_MAP[self.F_dtype], - F_bm0 = self.F_tile.F_bm0, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], F_bn1 = self.F_tile.F_bn1, F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], @@ -551,16 +635,18 @@ class FmhaFwdSplitKVCombineKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -568,16 +654,18 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdSplitKVCombineTileSize(64, 32, -1), - '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), - '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + '32' : FmhaFwdSplitKVCombineTileSize(32, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdSplitKVCombineTileSize(64, 64, -1), - '128' : FmhaFwdSplitKVCombineTileSize(64, 128, -1), - '256' : FmhaFwdSplitKVCombineTileSize(64, 256, -1), + '64' : FmhaFwdSplitKVCombineTileSize(32, -1), + '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } else: return None @@ -596,27 +684,35 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - # TODO: use async pipeline when compiler is more stable - if hdim == 256 or hdim in [32, 64, 128]: + for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + # TODO: use async pipeline when compiler is more stable + if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - # no need lse/paged-kv kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None else: assert False return pipelines @@ -624,7 +720,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> gen = list() api_pool = FmhaFwdSplitKVApiPool(mask_impl) - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_tile_dict_from_dtype(dtype) if d == None: continue @@ -637,9 +733,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if pipeline.F_pagedkv == 't': - # we only use batch mode kernels to handle (paged-) kvcache problems - continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -647,9 +743,10 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' @@ -657,6 +754,30 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16, bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -687,7 +808,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis gen = list() - for dtype in DTYPE_MAP.keys(): + for dtype in FWD_DTYPE_MAP.keys(): d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) if d == None: continue @@ -706,9 +827,20 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis F_mode=mode, F_tile=tile, F_pipeline=pipeline) - if kernel_filter != None: + if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + # Aiter(mha_varlen_fwd) integration + if receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + if not cond: + continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -720,21 +852,29 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] + + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + filter_list = filter_list.split('@') + filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] + with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(kernel_filter, receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py new file mode 100644 index 0000000000..650ebaf80e --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -0,0 +1,585 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_FWD_PAGEDKV_PIPELINE_MAP = { + "qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS" +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, //lse + {F_pagedkv}, //pagedkv + {F_squant}, + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdPagedKVKernel; + +using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + +#include + +template<> +float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + return fmha_fwd_pagedkv_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + pagedkv : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr_pagedkv', 'qs']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_pagedkv : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + + if self.F_pagedkv == 't' : n += '_pagedkv' + else: n += '_npagedkv' + + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + pagedkv=self.F_pipeline.F_pagedkv, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + elif dtype in ['fp8', 'bf8']: + # TODO + None + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + # if pipeline.F_pagedkv == 'f': + # continue + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' : + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index efae4e284a..b6de5ea621 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_bwd.hpp" #include "ck_tile/host.hpp" @@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") @@ -98,15 +101,28 @@ auto create_args(int argc, char* argv[]) } // different threshold for different dtype -template -auto get_elimit(int /*init_method*/) +template +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) { double rtol = 1e-2; double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } -template +template <> +auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) +{ + double rtol = 1e-2; + double atol = 1e-2; + if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN + { + rtol = 3.2e-2; + atol = 3.2e-2; + } + return ck_tile::make_tuple(rtol, atol); +} + +template bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); @@ -145,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser) float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + if(use_dbias && bias.type != bias_enum::elementwise_bias) { std::cerr << "dbias only exists when bias type is elementwise" << std::endl; @@ -191,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); - using TypeConfig = FmhaBwdTypeConfig; + using TypeConfig = FmhaBwdTypeConfig; using QDataType = typename TypeConfig::QDataType; using KDataType = typename TypeConfig::KDataType; @@ -337,7 +355,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); - assert(slopes.size() == nhead); + assert(slopes.size() == static_cast(nhead)); if(bias.rank_info == 0) { // alibi in 1*h @@ -368,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes()); @@ -378,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser) do_buf.ToDevice(do_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off @@ -459,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser) const ck_tile::index_t split_stride_dq_acc = (shape_batch * nhead * shape_seqlen_q * hdim_q); + const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) { + if(drop_prefs) + { + return std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + return std::make_pair(drop_seed, drop_offset); + } + }(); + return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), @@ -532,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser) static_cast(mask.type), p_drop, p_undrop, - {drop_seed, drop_offset}}; + drop_seed_offset}; }(); float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); @@ -722,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(p_drop > 0) { - p_hp_host_ref.ForEach( - [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + p_dropped_hp_host_ref = p_hp_host_ref; randval_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); }); ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); } else { - p_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); } // O = P * V @@ -820,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser) } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { - AccDataType do_dot_o = 0; - for(int o = 0; o < hdim_v; o++) - { - auto idx_gmo = idx_gmn; - idx_gmo[2] = o; - do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * - ck_tile::type_convert(o_host_refs[wb](idx_gmo)); - } - self(idx_gmn) = ck_tile::type_convert( - p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); - }); + ck_tile::make_ParallelTensorFunctor( + [&](auto i0, auto i1, auto i2) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + } + ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( + p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + }, + ds_hp_host_ref.mDesc.get_lengths()[0], + ds_hp_host_ref.mDesc.get_lengths()[1], + ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); if(use_dbias) { - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + dbias_host_ref = ds_hp_host_ref.template CopyAsType(); } - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout @@ -899,7 +926,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } // clang-format on - auto [rtol, atol] = get_elimit(init_method); + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); bool dq_cur_pass = ck_tile::check_err(dq_host_result, dq_host_ref, std::string("Error: QGrad Incorrect results!"), @@ -952,11 +979,11 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index aea42515dc..9179dbd9be 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -9,13 +9,24 @@ #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" #include "bias.hpp" + #include +#include +#include + +struct FmhaBwdFp16 +{ +}; + +struct FmhaBwdBf16 +{ +}; template struct FmhaBwdTypeConfig; template <> -struct FmhaBwdTypeConfig +struct FmhaBwdTypeConfig { using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; @@ -35,7 +46,7 @@ struct FmhaBwdTypeConfig }; template <> -struct FmhaBwdTypeConfig +struct FmhaBwdTypeConfig { using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; @@ -135,7 +146,8 @@ struct fmha_bwd_args ck_tile::index_t mask_type; float p_drop; float p_undrop; - std::tuple drop_seed_offset; + std::variant, std::pair> + drop_seed_offset; }; template @@ -146,113 +158,113 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { - return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } else { // create batch mode kernel arguments - return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.do_ptr, - args.d_ptr, - args.rand_val_ptr, - args.dk_ptr, - args.dv_ptr, - args.dbias_ptr, - args.dq_acc_ptr, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.scale, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_randval, - args.stride_do, - args.stride_dq_acc, - args.stride_dk, - args.stride_dv, - args.stride_dbias, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_randval, - args.nhead_stride_do, - args.nhead_stride_lsed, - args.nhead_stride_dq_acc, - args.nhead_stride_dk, - args.nhead_stride_dv, - args.nhead_stride_dbias, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_randval, - args.batch_stride_do, - args.batch_stride_lsed, - args.batch_stride_dq_acc, - args.batch_stride_dk, - args.batch_stride_dv, - args.batch_stride_dbias, - args.split_stride_dq_acc, - args.window_size_left, - args.window_size_right, - args.mask_type, - args.p_drop, - args.drop_seed_offset); + return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.do_ptr, + args.d_ptr, + args.rand_val_ptr, + args.dk_ptr, + args.dv_ptr, + args.dbias_ptr, + args.dq_acc_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_do, + args.stride_dq_acc, + args.stride_dk, + args.stride_dv, + args.stride_dbias, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_do, + args.nhead_stride_lsed, + args.nhead_stride_dq_acc, + args.nhead_stride_dk, + args.nhead_stride_dv, + args.nhead_stride_dbias, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_do, + args.batch_stride_lsed, + args.batch_stride_dq_acc, + args.batch_stride_dk, + args.batch_stride_dv, + args.batch_stride_dbias, + args.split_stride_dq_acc, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.drop_seed_offset); } }(); @@ -440,4 +452,5 @@ struct fmha_bwd_traits bool is_deterministic; // TODO: padding check is inside this api }; +template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp old mode 100644 new mode 100755 index 723546a452..e9403f4698 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" +#include "ck_tile/ref/naive_attention.hpp" #include "mask.hpp" #include "rotary.hpp" #include "utils.hpp" @@ -10,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -41,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "weather do CPU validation or not") + arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)") .insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") @@ -62,7 +64,7 @@ auto create_args(int argc, char* argv[]) "-1 to choose s_knew in [1, s] randomly.") .insert("s_kpad", "-1", - "seqlen_k stride between 2 tokens, currently used in group-mode only\n" + "seqlen_k stride between 2 batches, currently used in group-mode only\n" "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" "along seqlen, instead of packed. same as xformer kv_padding") .insert("d", "128", "head dim for q, k") @@ -71,6 +73,7 @@ auto create_args(int argc, char* argv[]) "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" "note when squant=1, this value will be modified by range_q/k") + .insert("logits_soft_cap", "0", "attention logits soft capping value.") .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") @@ -122,6 +125,9 @@ auto create_args(int argc, char* argv[]) .insert("p_drop", "0", "0~1 probability of dropout") .insert("drop_seed", "1", "seed for random number generator") .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert( "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") @@ -139,7 +145,7 @@ auto create_args(int argc, char* argv[]) } // different threshold for different dtype -template +template auto get_elimit(std::string /*init_method*/) { double rtol = 1e-3; @@ -148,7 +154,7 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string /*init_method*/) +auto get_elimit(std::string /*init_method*/) { double rtol = 1e-2; double atol = 1e-2; @@ -156,7 +162,7 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string init_method) { if(init_method == "ui" || init_method == "ni") { @@ -172,50 +178,30 @@ auto get_elimit(std::string init_method) } } -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if(batch_nhead_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + max_splits = std::min({max_splits, num_SMs}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || - ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) { - efficiency.push_back(0.f); - } - else - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); + max_efficiency = eff; } + efficiency.push_back(eff); } for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) - { - continue; - } if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits chosen = %d\n", num_splits); @@ -228,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int override_num_splits_if_necessary( int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) { + (void)hdim_v; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) @@ -244,21 +231,19 @@ int override_num_splits_if_necessary( // tile size should match the generate.py const int kM0 = 64; - const int kN1 = hdim_v; const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); if(num_splits < 1 && p_drop == 0.0f) { return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); } return num_splits; } -template +template bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); @@ -291,7 +276,8 @@ bool run(const ck_tile::ArgParser& arg_parser) #if !CK_TILE_FMHA_FWD_APPENDKV_API if(seqlen_knew != 0) { - std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl; + std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option" + << std::endl; seqlen_knew = 0; } #endif @@ -301,8 +287,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); - if constexpr(!(std::is_same_v || - std::is_same_v)) + if constexpr(!(std::is_same_v || + std::is_same_v)) { if(0 < rotary_dim) { @@ -318,6 +304,13 @@ bool run(const ck_tile::ArgParser& arg_parser) rotary_dim = 0; } #endif + // to use fmha_fwd_appendkv(), make sure it's in batch mode + const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); + if(need_append_kvcache && mode == mode_enum::group) + { + std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl; + mode = mode_enum::batch; + } if(!(rotary_dim <= hdim_q)) { std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; @@ -330,7 +323,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ + CK_TILE_FMHA_FWD_PAGEDKV_API)) if(0 < page_block_size) { std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" @@ -346,29 +340,33 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) if(use_cache_batch_idx) { std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" << std::endl; use_cache_batch_idx = false; } +#else + if(use_cache_batch_idx) + { + if(0 < page_block_size) + { + std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + else if(mode == mode_enum::group) + { + std::cerr << "group mode will not use cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + } #endif - if(0 < page_block_size && use_cache_batch_idx) - { - std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " - "'cache_batch_idx' option" - << std::endl; - use_cache_batch_idx = false; - } - // the input tensor layout for kvcache is same as batch mode - const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); - if(use_kvcache && mode != mode_enum::batch) - { - std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl; - mode = mode_enum::batch; - } auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode, @@ -377,7 +375,7 @@ bool run(const ck_tile::ArgParser& arg_parser) arg_parser.get_str("s_k"), arg_parser.get_str("s_kpad"), /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, - use_kvcache); + need_append_kvcache); // compute kvcache seqlen_k (before appending knew/vnew) auto cache_seqlen_ks = seqlen_ks; std::transform(cache_seqlen_ks.begin(), @@ -400,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + const float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + std::string squant_str = arg_parser.get_str("squant"); bool squant = [&]() { if(squant_str == "auto") @@ -413,25 +413,6 @@ bool run(const ck_tile::ArgParser& arg_parser) return atoi(squant_str.c_str()) != 0 ? true : false; }(); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); - - float dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max); - scale_p = dtype_max / range_p; - // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)] - scale_o = range_p * range_v / range_o / dtype_max; - } - std::string vlayout = arg_parser.get_str("vlayout"); bool lse = arg_parser.get_bool("lse"); @@ -442,6 +423,8 @@ bool run(const ck_tile::ArgParser& arg_parser) float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + if(p_drop < 0.0f || p_drop > 1.0f) { std::cerr << "The value of p_drop should be 0~1" << std::endl; @@ -449,7 +432,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool s_randval = false; - if(p_drop > 0.0f && do_validation) + if(p_drop > 0.0f && do_validation != 0) { s_randval = true; } @@ -482,7 +465,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - using TypeConfig = FmhaFwdTypeConfig; + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; using KDataType = typename TypeConfig::KDataType; @@ -496,6 +479,28 @@ bool run(const ck_tile::ArgParser& arg_parser) using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; + float range_q = arg_parser.get_float("range_q"); + float range_k = arg_parser.get_float("range_k"); + float range_v = arg_parser.get_float("range_v"); + float range_p = arg_parser.get_float("range_p"); + float range_o = arg_parser.get_float("range_o"); + + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float scale_p = 1.f; + float scale_o = 1.f; + + if(squant) + { + scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); + scale_p = p_dtype_max / range_p; + scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); + } + // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -517,8 +522,8 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_k = real_seqlen_k; } - flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + - static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + + static_cast(2) * mask.get_unmaskarea() * hdim_v); num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + sizeof(KDataType) * real_seqlen_k * hdim_q + @@ -543,7 +548,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cerr << "num_splits greater than 128 is not supported" << std::endl; return false; } -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < p_drop && (1 < num_splits || use_kvcache)) { std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" @@ -552,11 +557,11 @@ bool run(const ck_tile::ArgParser& arg_parser) } #endif - auto get_lengths = [&](bool permute, - ck_tile::index_t b /*batch*/, - ck_tile::index_t h /*nhead*/, - ck_tile::index_t s /*seqlen*/, - ck_tile::index_t d /*hdim*/) { + static const auto get_lengths = [](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { if(permute) return std::array{b, h, s, d}; else @@ -599,7 +604,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor bias_host( bias.type == bias_enum::elementwise_bias - ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor alibi_slope_host( @@ -613,12 +618,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor lse_acc_host( 1 < num_splits || use_kvcache - ? std::array{num_splits, shape_batch, nhead, shape_seqlen_q} + ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 < num_splits || use_kvcache - ? std::array{num_splits, batch, nhead, max_seqlen_q, hdim_v} - : std::array{1, 1, 1, 1, 1}); + 1 < num_splits || use_kvcache ? std::array{shape_batch, + nhead, + num_splits, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -689,14 +697,14 @@ bool run(const ck_tile::ArgParser& arg_parser) else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") // suitable for fp8 quantization { - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(knew_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host); - ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(vnew_host); + ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, seed}(q_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, seed}(knew_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, seed}(v_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, seed}(vnew_host); // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); + float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); // Assume bias is in [-1.f, 1.f] in original fp32 ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); } @@ -733,12 +741,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqlen_k_buf( - use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || + 0 <= seqlen_kpads[0] + ? seqlen_ks.size() * sizeof(int32_t) + : 0); ck_tile::DeviceMem cache_seqlen_k_buf( need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); @@ -753,10 +765,14 @@ bool run(const ck_tile::ArgParser& arg_parser) seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() : seqstart_k_with_padding_host.data()); - seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + ? seqlen_ks.data() + : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); block_table_buf.ToDevice(block_table_host.data()); cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); @@ -787,7 +803,7 @@ bool run(const ck_tile::ArgParser& arg_parser) << (is_rotary_interleaved ? "inter" : "half") << ")"; } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(1 < num_splits) { std::cout << ", num_splits:" << num_splits; @@ -818,6 +834,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else // fmha_fwd_traits or fmha_splitkv_traits { traits.is_group_mode = (mode == mode_enum::group); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -827,6 +844,11 @@ bool run(const ck_tile::ArgParser& arg_parser) { traits.has_dropout = (p_drop > 0.0f); } + else if constexpr(std::is_same_v>) + { + traits.use_pagedkv = use_kvcache; + } } }; @@ -852,9 +874,9 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); - const ck_tile::index_t stride_o_acc = hdim_v; + const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); // setup nhead_stride_* arguments const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); @@ -877,11 +899,11 @@ bool run(const ck_tile::ArgParser& arg_parser) return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); const ck_tile::index_t nhead_stride_bias = - (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q; - const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); @@ -893,16 +915,16 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); // setup split_stride_* arguments (only used in split-kv kernel) - const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); - const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); + const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); args.q_ptr = q_buf.GetDeviceBuffer(); args.k_ptr = k_buf.GetDeviceBuffer(); @@ -964,8 +986,9 @@ bool run(const ck_tile::ArgParser& arg_parser) (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); args.seqstart_k_ptr = (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); - args.seqlen_k_ptr = - (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) args.max_seqlen_q = max_seqlen_q; @@ -974,6 +997,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.scale_p = scale_p; args.scale_o = scale_o; + args.logits_soft_cap = logits_soft_cap; + args.stride_bias = (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); args.stride_o = stride_o; @@ -996,9 +1021,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.nhead_stride_randval = nhead_stride_randval; args.batch_stride_randval = batch_stride_randval; - args.p_drop = p_drop; - args.s_randval = s_randval; - args.drop_seed_offset = std::tie(drop_seed, drop_offset); + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + { + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + } } else if constexpr(std::is_same_v>) { @@ -1009,6 +1042,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); args.batch_stride_block_table = batch_stride_block_table; args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration args.cache_batch_idx = (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); @@ -1023,6 +1057,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_o_acc = split_stride_o_acc; } + else if constexpr(std::is_same_v>) + { + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + } } }; @@ -1044,7 +1089,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || use_kvcache) + if(1 < num_splits && use_kvcache) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); @@ -1054,6 +1099,18 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } +#endif +#if CK_TILE_FMHA_FWD_PAGEDKV_API + if(use_kvcache) + { + fmha_fwd_pagedkv_traits fmha_pagedkv_traits; + init_traits(fmha_pagedkv_traits); + + fmha_fwd_pagedkv_args fmha_pagedkv_args; + init_args(fmha_pagedkv_args); + + return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + } #endif fmha_fwd_traits fmha_traits; init_traits(fmha_traits); @@ -1080,25 +1137,76 @@ bool run(const ck_tile::ArgParser& arg_parser) << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << " GB/s" << std::flush; - if(!do_validation) + if(do_validation == 0) { std::cout << std::flush << std::endl; return true; } + if(do_validation == 2) + { + // NOTE: use gpu to do validation + ck_tile::naive_attention_fwd_traits naive_t; + naive_t.q_type = data_type; + naive_t.k_type = data_type; + naive_t.v_type = data_type; + naive_t.o_type = data_type; + naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; + naive_t.variation = 0; // TODO? + naive_t.quant_algo = 0; + + ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); + + ck_tile::naive_attention_fwd_args naive_a; + naive_a.q_ptr = q_buf.GetDeviceBuffer(); + naive_a.k_ptr = k_buf.GetDeviceBuffer(); + naive_a.v_ptr = v_buf.GetDeviceBuffer(); + naive_a.o_ptr = o_naive_buf.GetDeviceBuffer(); + naive_a.scale_s = scale_s; + naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer + naive_a.page_table_ptr = + nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn) + naive_a.hdim = hdim_q; + naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different + naive_a.batch_q = batch; + naive_a.batch_kv = batch; + naive_a.batch_ratio_kv = 1; // batch_q / batch_kv + naive_a.seqlen_q = seqlen_qs[0]; + naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field + naive_a.nhead_q = nhead; + naive_a.nhead_kv = nhead_k; + naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv + naive_a.page_size = 0; // if paged, the seqlen-kv for each block + + ck_tile::stream_config naive_s{}; + + naive_attention_fwd(naive_t, naive_a, naive_s); + + auto o_naive_ref = o_naive_buf.ToHost(); + o_buf.FromDevice(o_host.data()); // TODO: ugly + + auto [rtol_, atol_] = get_elimit(init_method); + bool pass_ = ck_tile::check_err( + o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); + std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl; + return pass_; + } o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); auto p_compute_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return ck_tile::scales{scale_p}; else return ck_tile::identity{}; }(); auto oacc_element_func = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); else @@ -1148,7 +1256,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); - auto [rotary_cos_slice, rotary_sin_slice] = + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); ck_tile::reference_batched_rotary_position_embedding( @@ -1158,19 +1266,19 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(i_perm) { k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); }); - } else { + } else { k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); }); } } else -#endif +#endif { if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); @@ -1191,7 +1299,7 @@ bool run(const ck_tile::ArgParser& arg_parser) { knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); - auto [rotary_cos_slice, rotary_sin_slice] = + auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); ck_tile::reference_batched_rotary_position_embedding( @@ -1209,23 +1317,23 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(is_v_rowmajor) { if(i_perm) { - v_host_ref.ForEach([&](auto& self, auto i) { - self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); }); } else { - v_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); }); } } - else + else { - if(i_perm) { - v_host_ref.ForEach([&](auto& self, auto i) { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); }); } else { @@ -1282,15 +1390,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); + if(0.f < logits_soft_cap) + { + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + if(bias.type == bias_enum::elementwise_bias) { // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, @@ -1420,7 +1538,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on - auto [rtol, atol] = get_elimit(init_method); + auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); pass &= cur_pass; @@ -1477,15 +1595,15 @@ int main(int argc, char* argv[]) const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "bf16") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } else if(data_type == "fp8") { - return run(arg_parser) ? 0 : -2; + return run(arg_parser) ? 0 : -2; } return -3; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 183475064a..81dda692ea 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -13,12 +13,38 @@ #include "rotary.hpp" #include +#include +#include + +struct FmhaFwdFp16 +{ +}; + +struct FmhaFwdBf16 +{ +}; + +struct FmhaFwdFp8 +{ +}; + +struct FmhaFwdBf8 +{ +}; + +struct FmhaFwdFp8Fp16 +{ +}; + +struct FmhaFwdFp8Bf16 +{ +}; template struct FmhaFwdTypeConfig; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::half_t; using KDataType = ck_tile::half_t; @@ -34,7 +60,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t; @@ -50,7 +76,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t; @@ -66,7 +92,7 @@ struct FmhaFwdTypeConfig }; template <> -struct FmhaFwdTypeConfig +struct FmhaFwdTypeConfig { using QDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t; @@ -117,6 +143,8 @@ struct fmha_fwd_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -141,10 +169,93 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; float p_drop; bool s_randval; - std::tuple drop_seed_offset; + + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_fwd_pagedkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; }; struct fmha_fwd_splitkv_args @@ -161,6 +272,8 @@ struct fmha_fwd_splitkv_args void* block_table_ptr; ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. const void* cache_batch_idx; @@ -169,9 +282,21 @@ struct fmha_fwd_splitkv_args // seqlen_k = kargs.seqlen_k // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] - // kvcache mode (use same kernel as batch mode): + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; @@ -190,6 +315,8 @@ struct fmha_fwd_splitkv_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -247,7 +374,7 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr - const void* cache_batch_idx; + const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) ck_tile::index_t stride_q; ck_tile::index_t stride_k; @@ -266,8 +393,196 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_vnew; }; +struct fmha_batch_prefill_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + const void* seqstart_q_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + // SGLang-style page table + int32_t num_total_pages; + void* kv_indptr; + void* kv_page_indices; +#if 0 // we assume page_block_size=1 for now + void* kv_last_page_lens; + ck_tile::index_t page_block_size; +#endif + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + +template +auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { @@ -278,7 +593,6 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, - args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, @@ -288,28 +602,31 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, args.stride_bias, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, - args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, args.window_size_left, args.window_size_right, args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.min_seqlen_q); } else { // create batch mode kernel arguments @@ -317,49 +634,59 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.k_ptr, args.v_ptr, args.bias_ptr, - args.rand_val_ptr, args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, + args.seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, args.stride_bias, - args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias, - args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, args.batch_stride_bias, - args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type, - args.p_drop, - args.s_randval, - args.drop_seed_offset); + args.mask_type); } }(); - dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); - return ck_tile::make_tuple(kargs, grids); + // FmhaKernel::PrintParameters(kargs, args.batch); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } } template @@ -385,8 +712,13 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_q, args.nhead_q / args.nhead_k, args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -398,10 +730,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_bias, args.nhead_stride_lse_acc, args.nhead_stride_o_acc, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left, @@ -431,6 +761,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.cache_batch_idx, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -456,8 +787,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) } }(); - dim3 grids = - Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.num_splits); + dim3 grids = Kernel::GridSize( + args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); return ck_tile::make_tuple(kargs, grids); } @@ -475,7 +806,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_ptr, args.o_ptr, args.batch, - args.max_seqlen_q, args.seqstart_q_ptr, args.hdim_v, args.num_splits, @@ -486,7 +816,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.nhead_stride_o_acc, args.nhead_stride_lse, args.nhead_stride_o, - args.batch_stride_o_acc, args.split_stride_lse_acc, args.split_stride_o_acc); } @@ -497,7 +826,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.lse_ptr, args.o_ptr, args.batch, - args.max_seqlen_q, args.seqlen_q, args.hdim_v, args.num_splits, @@ -567,6 +895,117 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template + bool kPadDv_, + bool kSkipMinSeqlenQ_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -601,6 +1042,7 @@ struct fmha_fwd_traits_ static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; using FmhaMask = ck_tile::remove_cvref_t; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -610,6 +1052,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; template @@ -626,6 +1069,58 @@ template +struct fmha_fwd_pagedkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +template +float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args); + +template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -668,7 +1164,6 @@ std::string fmha_fwd_splitkv_get_name_(); template ; static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr ck_tile::index_t kM0 = kM0_; static constexpr ck_tile::index_t kN1 = kN1_; static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; @@ -727,6 +1221,9 @@ struct fmha_fwd_appendkv_traits_ template float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); +template +float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -735,15 +1232,38 @@ struct fmha_fwd_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; bool has_dropout; bool do_fp8_static_quant; + bool skip_min_seqlen_q = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +struct fmha_fwd_pagedkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse = false; + bool use_pagedkv = true; + bool do_fp8_static_quant = false; + bool skip_min_seqlen_q = false; + // TODO: padding check is inside this api +}; + +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits&, + fmha_fwd_pagedkv_args&, + const ck_tile::stream_config&); + struct fmha_fwd_splitkv_traits { int hdim_q; @@ -751,6 +1271,7 @@ struct fmha_fwd_splitkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; @@ -772,3 +1293,8 @@ struct fmha_fwd_appendkv_traits float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); + +using fmha_batch_prefill_traits = fmha_fwd_traits; +float fmha_batch_prefill(fmha_batch_prefill_traits, + fmha_batch_prefill_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 9b91d36fb2..c611618824 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -17,12 +17,11 @@ class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 -# inspect all modules under 'codegen.ops' and register API handlers +# inspect all modules under 'codegen.ops' and register API handlers ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) - if full_module_name not in sys.modules: - ops.append(importer.find_spec(module_name).loader.load_module(module_name)) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) unwanted_prefix = 'fmha_' handlers = dict( [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, @@ -30,7 +29,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -38,18 +37,21 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : output_dir.mkdir(parents=True, exist_ok=True) - for api in api_list: + for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, mask_impl) + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) - for api in api_list: + # create an empty file / drop its contents if it exists + open(file_path, "w").close() + + for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, mask_impl) + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -81,6 +83,7 @@ if __name__ == "__main__": parser.add_argument( "-f", "--filter", + default='', required=False, help="filter out kernels that need to generate, using fnmatch module" ) @@ -100,12 +103,33 @@ if __name__ == "__main__": required=False, help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ " 1: generate more instance to cover all hdim\n" + \ - " 2: Only generate instance for Flash attention integration" + " 2: Only generate instance for Flash attention integration\n" + \ + " 4: Only generate instance for PyTorch integration\n" + \ + " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ + " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ + " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + ) + + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. --optdim=32,64,128,256" ) args = parser.parse_args() api_list = args.direction.split(',') + filter_list = args.filter.split(',') + filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + + if len(api_list) > 1: + assert optdim_list == [-1] + if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp old mode 100644 new mode 100755 index c77b700b16..b96482f535 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -21,6 +21,8 @@ enum class mask_enum struct mask_info { mask_enum type; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right @@ -42,6 +44,8 @@ struct mask_info ck_tile::index_t x_total = seqlen_k; ck_tile::index_t y_total = seqlen_q; mask_info tmp; + tmp.seqlen_q = seqlen_q; + tmp.seqlen_k = seqlen_k; auto found_0 = str.find(':'); if(found_0 != std::string::npos) { @@ -148,7 +152,22 @@ struct mask_info } return tmp; } - + ck_tile::index_t get_unmaskarea() const + { + if(type == mask_enum::no_mask) + return seqlen_q * seqlen_k; + ck_tile::index_t area = 0; + for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) + { + ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) { mi.serialize(os); diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 5dcc6ed42b..b867cd6c07 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -29,14 +29,14 @@ while getopts ":sa" opt; do done run_fp16_bf16_tests() { - local NUM_SPLITS=(1) - local PAGE_BLOCK_SIZE=(0) - local CACHE_BATCH_IDX=(0) + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" if [ $TEST_SPLITKV -eq 1 ] ; then - NUM_SPLITS+=(2 3) - PAGE_BLOCK_SIZE+=(128) - CACHE_BATCH_IDX+=(1) + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" fi for prec in "fp16" "bf16" ; do @@ -47,9 +47,9 @@ run_fp16_bf16_tests() { for lse in 0 1 ; do for bias in "n" "e" "a" ; do for p_drop in 0.0 0.2 ; do - for num_splits in "${NUM_SPLITS[@]}" ; do - for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do - for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS @@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests fi -set +x \ No newline at end of file +set +x diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp index 996032a717..faf3f08437 100644 --- a/example/ck_tile/01_fmha/utils.hpp +++ b/example/ck_tile/01_fmha/utils.hpp @@ -145,7 +145,7 @@ decode_seqlen(mode_enum mode, std::string k_val, std::string k_pad_val, ck_tile::index_t seqlen_k_min = 0, - bool use_kvcache = false, + bool need_append_kvcache = false, std::optional seed = std::nullopt) { #define _S2I_(str_) static_cast(std::atoi((str_).c_str())) @@ -159,7 +159,7 @@ decode_seqlen(mode_enum mode, const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); std::vector seqlen_ks(batch, seqlen_k_max); - if(1 < batch && use_kvcache) + if(1 < batch && need_append_kvcache) { // to keep the original s_k value, we always use seqlen_k_max in first batch randints(std::next(seqlen_ks.begin()), diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index bac5f45cd3..07714f0fe2 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -1,4 +1,44 @@ -# not using add_example_executable() to add this target, since we don't want this to have -# to be included in "make all/install/check" -add_executable(tile_example_layernorm2d_fwd EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) -target_compile_options(tile_example_layernorm2d_fwd PRIVATE -DSAVE_MEAN_INV_STD) \ No newline at end of file +set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd") +set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".") +if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all") + set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS}) +endif() + +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs +) + +set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") + +message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}") +add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) +target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) + +set(EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress) + +target_compile_options(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${EXAMPLE_LAYERNORM2D_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 433dad04e6..817f62dae7 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -1,6 +1,42 @@ # Layernorm2D forward -This folder contains example for Layernorm2D forward using ck_tile tile-programming implementation. +This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation. + +# Implementation and feature support + +## welford online algorithm +We use welfold algorithm to update `mean`/`variance` block by block. For `N <=4096` case we can compute `mean`/`var`/`normalization` within one loop, we call it `one-pass`. For large N case, it is hard to keep `mean`/`var` inside register/LDS and then computation `normalization`, so we need to load input twice, first time to compute `mean`/`var` block-by-block, then load input another time to compute the `normalization`. We call it `two-pass`. + +## mean/variance save +In training case the mean/variance need to store out (TBD, not supported yet) + +## prenorm/postnorm + +![](misc/pnorm.png) + +since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default). + +## smooth-quant/dynamic-quant +we support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output will doing a rowwise dynamic quantization like below. Note that smooth-quant require input a `(1*N)` size per-channel scale(in fp32 in our example, though this is customizable), then elememt-wise multiply the tensor for each row, then compute the rowwise dynamic quant. if set `-fquant=2` will have the input per-channel scale stage, only the dynamic quant. This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py support on-demand codegen) +![](misc/dquant.png) + +``` +# assume output int8, hidden_states is [m, n] shape and in fp16/bf16 +# [m, 1] +per_token_amax, _ = torch.max( + input=torch.abs(hidden_states), + dim=-1, + keepdim=True +) +per_token_scale = per_token_amax.to(dtype=torch.float32) / 127.0 + +# quant hidden_states +hidden_states = (hidden_states / per_token_scale).to(dtype=torch.int8) + +return hidden_states, per_token_scale +# hidden_states now is int8 will feed to next layer as intput +# per_token_scale will be used as dequant factor later layer +``` ## build ``` @@ -15,8 +51,35 @@ This will result in an executable `build/bin/tile_example_layernorm2d_fwd` ``` args: -m m dimension (default:3328) - -n m dimension (default:4096) + -n n dimension (default:4096) + -stride stride per row, if -1 then equal to n (default:-1) -e epsilon (default:1e-5) + -save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0) -v cpu validation or not (default:1) - -prec precision (default:fp16) -``` \ No newline at end of file + -kname print kernel name or not (default:1) + -prec_i input precision (default:fp16) + -prec_o output precision, set auto will be the same as input (default:auto) + -prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto) + -prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto) + -fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0) + -fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0) + -warmup cold iter (default:5) + -repeat hot iter (default:20) + +``` + +## limitations +Note that `fquant=2`, `fadd=2`, `prec_sm/prec_sy` other than `fp32` are not by default generated. Though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, `N>8192` case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. If need suport `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant`, to construct layernorm+residual, and smoothquant, 2 kernels for this purpose. + +``` +# some case +# standard fp16 layernorm 2d, m=10. n=1024 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 + +# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant, output in int8 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 + +# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8 +./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1 + +``` diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py new file mode 100644 index 0000000000..d77582630a --- /dev/null +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -0,0 +1,730 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Any +import functools +import itertools +import copy +from dataclasses import dataclass + +def get_if_str(idx, total, lase_else = True): + if idx == 0: + return 'if' + elif idx < total - 1: + return 'else if' + else: + if lase_else: + return 'else' + else: + return 'else if' + +XBIAS_ENUM_STR_MAP = [ + 'no', + 'xbias'] # pre-norm add bias + +FUSED_ADD_ENUM_STR_MAP = [ + 'no', + 'pras', # pre-norm + 'pra' ] # post-norm + +FUSED_FUSED_SWEEP_STR_MAP = [ + 'no', + 'dquant' ] + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::fp16_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t'} + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + +class layernorm_fwd_codegen: + API_TRAITS_DEFINE = """ +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct layernorm2d_fwd_traits_ +{ + using XDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using SmoothScaleDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kFastFDiv = kFastFDiv_; + static constexpr bool kWelford = kWelford_; + static constexpr bool kTwoPass = kTwoPass_; + static constexpr ck_tile::index_t kXbias = kXbias_; + static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; + static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; +}; + +template +using traits_ = layernorm2d_fwd_traits_; +""" + API_COMMON_HEADER = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "layernorm2d_fwd.hpp" +#include +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = layernorm2d_fwd_args; + +{F_traits_define} + +template +float layernorm2d_fwd_(const S& s, A a) +{{ + using XDataType = typename Traits_::XDataType; + using YDataType = typename Traits_::YDataType; + using SmoothScaleDataType = typename Traits_::SmoothScaleDataType; + using YScaleDataType = typename Traits_::YScaleDataType; + using ComputeDataType = typename LayerNormTypeConfig::ComputeDataType; + + using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kXbias), + static_cast(Traits_::kFusedAdd), + static_cast(Traits_::kFusedQuant)>; + using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem< + typename LayerNormTypeConfig::XDataType, + typename LayerNormTypeConfig::XBiasDataType, + typename LayerNormTypeConfig::GammaDataType, + typename LayerNormTypeConfig::BetaDataType, + typename LayerNormTypeConfig::ComputeDataType, + typename LayerNormTypeConfig::YDataType, + typename LayerNormTypeConfig::MeanDataType, + typename LayerNormTypeConfig::InvStdDataType, + typename LayerNormTypeConfig::SmoothScaleDataType, + typename LayerNormTypeConfig::YScaleDataType, + typename Traits_::Shape, + PipelineTraits>; + + using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; + using Default2DEpilogue = ck_tile::Default2DEpilogue; + + static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; + static constexpr bool UseRawStore = sizeof(YDataType) == 4; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + + using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; + + using Epilogue = std::conditional_t; + + using Kernel = ck_tile::Layernorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); +}} + +""" + + API_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "layernorm2d_fwd.hpp" + +{F_traits_define} + +// Note: this internal API only declare, not define here, otherwise will block `make -j` +template +float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a); + +float layernorm2d_fwd(layernorm2d_fwd_traits t, + layernorm2d_fwd_args a, + const ck_tile::stream_config& s) +{{ + float r = -1; +{F_dispatch} + return r; +}} + +""" + + API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{ +{F_per_n_case} + }} +""" + API_PER_N_CASE=""" {F_if} {F_N_COND} {{ +{F_inner_dispatch} + }} +""" + API_INNER_CASE=""" {F_if} {F_VEC_COND} + r={F_instance_func}(s, a); +""" + + INSTANCE_BASE = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "layernorm2d_fwd_api_common.hpp" + +// clang-format off +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep +{F_instance_def} +// clang-format on + +""" + + def __init__(self, working_path, kernel_filter): + self.working_path = working_path + self.kernel_filter = kernel_filter + + class k_xbias_enum(IntEnum): + F_NO_XBIAS = 0 + F_ADD_XBIAS = 1 + + class k_fuesd_add_enum(IntEnum): + F_NO_ADD = 0 + F_PRE_ADD = 1 + F_PRE_ADD_STORE_RESIDUAL = 2 + + class k_fused_sweep_enum(IntEnum): + F_NO_SWEEP = 0 + F_RENORM = 1 + F_DYNAMIC_QUANT = 2 + + @dataclass + class k_traits: + F_kPadN : bool + F_kSaveMeanInvStd : bool + F_kTwoPass : bool + F_kXbias : Any #: layernorm_fwd_codegen.k_bias_enum + F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum + F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum + + @dataclass + class k_shape: + F_BlockTile : List[int] + F_WarpPerBlock : List[int] + F_WarpTile : List[int] + F_Vector_ : List[int] + @property + def F_BlockSize(self) -> int: + return functools.reduce(lambda a, b: a*b, self.F_WarpTile) + + @dataclass + class k_problem: + F_XDataType : str + F_XBiasDataType : str + F_GammaDataType : str + F_BetaDataType : str + F_ComputeDataType : str + F_YDataType : str + F_MeanDataType : str + F_InvStdDataType : str + F_BlockShape : str + F_Traits : Any #k_traits + + @dataclass + class k_pipeline_one_pass: + F_Problem : Any #k_problem + + @dataclass + class k_pipeline_two_pass: + F_Problem : Any #k_problem + + @dataclass + class default_2d_epilogue_problem: + F_AccDataType : str + F_ODataType : str + F_kPadM : bool + F_kPadN : bool + + @dataclass + class default_2d_epilogue: + F_problem : Any + + @dataclass + class k_kernel: + F_pipeline : Any + F_epilogue : Any + + @dataclass + class h_traits: + F_XDataType : str + F_YDataType : str + F_SmoothScaleDataType : str + F_YScaleDataType : str + F_Repeat_M : int + F_Repeat_N : int + F_ThreadPerBlock_M : int + F_ThreadPerBlock_N : int + F_Vector_N : int + F_kPadN : bool + F_kSaveMeanInvStd_ : bool + F_kFastFDiv_ : bool + F_kWelford_ : bool + F_kTwoPass_ : bool + F_kXbias_ : int + F_kFusedAdd : int + F_kFusedQuant : int + + @property + def trait_name(self) ->str: + t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}' + t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' + return t_ + + # string when calling this kernel + @property + def call_name(self) -> str: + return f'layernorm2d_fwd_>' + + # string when define this kernel + @property + def def_name(self) -> str: + return f'template float layernorm2d_fwd_>(const S&, A);' + + # this class hold kernel under same source file + @dataclass + class h_instance: + F_DataTypePair : str + F_N : str + F_xbias : int + F_add : int + F_sweep : int + instance_list : List[Any] # List[h_traits] + + @property + def name(self) -> str: + prec_i, prec_o = self.F_DataTypePair.split(',') + dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}' + nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}' + if self.F_xbias != 0: + nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias] + if self.F_add != 0: + nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add] + if self.F_sweep != 0: + nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep] + return nnn + + @property + def instance_name(self) ->str: + return self.name + + @property + def content(self) ->str: + instance_defs = '' + for ins in self.instance_list: + instance_defs += ins.def_name + '\n' + return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs) + + @property + def name_api(self) -> str: + return 'layernorm2d_fwd_api' + + @property + def name_common_header(self) -> str: + return 'layernorm2d_fwd_api_common' + + def content_api(self, args) -> str: + # 1 sort based on dtype + t_dtype_dict = dict() + blobs = self.get_blobs(args) + for blob in blobs: + if blob.F_DataTypePair not in t_dtype_dict: + t_dtype_dict[blob.F_DataTypePair] = {} + if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]: + t_dtype_dict[blob.F_DataTypePair][blob.F_N] = [] + t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob) + + d_str = '' + for i_d, dtype_ in enumerate(t_dtype_dict): + blob_per_t = t_dtype_dict[dtype_] + n_str = '' + for i_n, n_ in enumerate(blob_per_t): + blob_per_n = blob_per_t[n_] + inner_str = "" + for i_b, b_ in enumerate(blob_per_n): + # generate single kernel instance file + #vec_str = "" + for i_ins, ins in enumerate(b_.instance_list): + idx_in_n = i_b * len(b_.instance_list) + i_ins + len_in_n = len(blob_per_n) * len(b_.instance_list) + # _if = 'if' if i_ins == 0 else 'else if' + if ins.F_kFusedQuant == 0: + _sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant) + elif ins.F_kFusedQuant == 1: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType) + elif ins.F_kFusedQuant == 2: + _sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format( + f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType) + _cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format( + f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd, + f_sweep_cond = _sweep_cond) + inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False), + F_VEC_COND = _cond, F_instance_func=ins.call_name) + #inner_str = inner_str + vec_str + n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else '' + n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str) + prec_i, prec_o = dtype_.split(',') + d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str) + + api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str) + return api_base + + @property + def content_common_header(self) -> str: + return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE) + + def get_blobs(self, args): + h_traits = layernorm_fwd_codegen.h_traits + h_instance = layernorm_fwd_codegen.h_instance + + dynamic_quant_out_dtype = ['int8', 'fp8'] + # some predefined support range + # (prec_i,prec_o) for simplicity this string will be used as key for dict + scale_list = [('fp32,fp32')] + dtype_list = [('fp16,fp16'), ('bf16,bf16'), + ('fp16,int8'), ('bf16,int8'), + ('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out + types_8bit = ('int8', 'fp8') + types_16bit = ('int16', 'fp16', 'bf16') + #fused_add_list = [0, 1, 2] + #fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant + xbias_list = [0, 1] + fused_add_list = [0, 1] + fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant + # rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} + total_blob = list() + for hs_key in h_trait_dict: + hs = h_trait_dict[hs_key] + current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N + for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list): + prec_i, prec_o = dtype.split(',') + scale_sm, scale_y = scale_type.split(',') + if prec_o in dynamic_quant_out_dtype and fused_quant != 1: + continue # skip non dynamic quant case + if fused_quant == 1 and hs_key == 'big': + continue + current_hs = list() + for chs_ in hs: + h_ = copy.copy(chs_) # copy the base instance out + h_.F_XDataType = prec_i + h_.F_YDataType = prec_o + h_.F_SmoothScaleDataType = scale_sm + h_.F_YScaleDataType = scale_y + h_.F_kXbias = xbias + h_.F_kFusedAdd = fused_add + h_.F_kFusedQuant = fused_quant + # disable welford update for 8bit and 16 bit smallN + if not h_.F_kTwoPass_: + #disable 16 bit when set args disable_16b_welford + if args.disable_16b_welford and prec_i in types_16bit: + h_.F_kWelford_ = False + #disable 8bit by default + elif prec_i in types_8bit or prec_o in types_8bit: + h_.F_kWelford_ = False + #disable 16bit small N + elif prec_i in types_16bit and hs_key == '64': + h_.F_kWelford_ = False + current_hs.append(h_) # + "\n" + #f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_ + current_n_str = 'big' if hs_key == 'big' else current_n + total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs)) + return total_blob + + def list_blobs(self, args) -> None: + w_p = Path(self.working_path) + list_p = w_p / 'layernorm2d_fwd_blobs.txt' + blobs = self.get_blobs(args) + with list_p.open('w') as list_f: + # api related file + list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") + list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") + # kernel instance file + for b in blobs: + list_f.write(str(w_p / (b.name + ".cpp")) + "\n") + + def gen_blobs(self, args) -> None: + w_p = Path(self.working_path) + w_str = self.content_api(args) + (w_p / (self.name_api + ".cpp")).write_text(w_str) + (w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header) + blobs = self.get_blobs(args) + for b in blobs: + (w_p / (b.name + ".cpp")).write_text(b.content) + +def list_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args) + + +def gen_blobs(args): + api_list = args.api.split(',') + for api in api_list: + if api == 'fwd': + layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK layernorm kernel", + ) + parser.add_argument( + "-a", + "--api", + default='fwd[all]', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + + # the directory for list_blobs/gen_blobs to write files into + parser.add_argument( + "-w", + "--working_path", + default="./", + required=False, + help="the path where all the blobs are going to be generated" + ) + + # this script have 2 modes + # 1) list_blobs mode, will generate a txt file with all the files going to be generated. + # this is useful in build system like cmake to construct source code dependency, by + # reading the content out of this file + # 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework + # like FA, only need to use this mode + parser.add_argument( + "-l", + "--list_blobs", + action='store_true', + help="list all the kernels to a file, " + ) + + parser.add_argument( + "-g", + "--gen_blobs", + action='store_true', + help="generate all kernels into different tile" + ) + + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-t", + "--traits", + default="all", + required=False, + help="enable/disable some feature. default generate all" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt." + ) + + parser.add_argument( + "--disable_16b_welford", + default=False, + required=False, + help="enable/disable welford for 16bit datatype n > 64" + ) + + args = parser.parse_args() + + # print(f'{args.list_blobs}-{args.gen_blobs}') + if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)): + print('gen_blobs/list_blobs must specify only one option') + sys.exit() + + p = Path(args.working_path) + if not p.exists(): + p.mkdir() + + if args.list_blobs: + list_blobs(args) + else: + gen_blobs(args) diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 9cbd286104..b72485222e 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -1,166 +1,277 @@ #include "ck_tile/host.hpp" #include "layernorm2d_fwd.hpp" +#include #include -// Host API implementation -float layernorm2d_fwd(layernorm2d_fwd_traits t, - layernorm2d_fwd_args a, - const ck_tile::stream_config& s) +// different threshold for different dtype +template +auto get_elimit() { - if(t.data_type.compare("fp16") == 0) - { - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; -#ifdef SAVE_MEAN_INV_STD - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; -#else - using MeanDataType = ck_tile::null_type; - using InvStdDataType = ck_tile::null_type; -#endif - using ComputeDataType = float; + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} - using thread_tile = ck_tile::sequence<4, 4>; - using warp_tile = ck_tile::sequence<8, 128>; - using block_tile = ck_tile::sequence<32, 128>; +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} - using Shape = ck_tile::TileLayernorm2dShape; - - using PipelineProblem = ck_tile::BlockLayernorm2dFwdProblem; - - using Kernel = ck_tile::Layernorm2dFwd; - - auto kargs = Kernel::MakeKargs( - a.p_x, a.p_gamma, a.p_beta, a.p_y, a.p_mean, a.p_invStd, a.epsilon, a.M, a.N); - - const dim3 grids = Kernel::GridSize(a.M); - constexpr dim3 blocks = Kernel::BlockSize(); - - constexpr ck_tile::index_t kBlockPerCu = Shape::kMWarpPerBlock * Shape::kNWarpPerBlock; - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - } - - return 0; +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1.0; + return ck_tile::make_tuple(rtol, atol); } auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") - .insert("n", "4096", "m dimension") + .insert("n", "4096", "n dimension") + .insert("x_stride", "-1", "x row_stride, if -1 then equal to n") + .insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n") + .insert("y_stride", "-1", "y row_stride, if -1 then equal to n") + .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") + .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") - .insert("prec", "fp16", "precision"); + .insert("kname", "1", "print kernel name or not") + .insert("prec_i", "fp16", "input precision") + .insert("prec_o", "auto", "output precision, set auto will be the same as input") + .insert("prec_sm", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1") + .insert("prec_sy", + "auto", + "output quant scale type, set auto will use fp32. used when fquant=1 or 2") + .insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd") + .insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } -int main(int argc, char* argv[]) +template +bool run(const ck_tile::ArgParser& arg_parser) { + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride"); + if(xr_stride < 0) + xr_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride"); + if(yr_stride < 0) + yr_stride = n; + float epsilon = arg_parser.get_float("e"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sm = arg_parser.get_str("prec_sm"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sm == "auto") + { + prec_sm = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int xbias = arg_parser.get_int("xbias"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8") + { + std::cout + << "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases." + << std::endl; + return false; + } - float epsilon = arg_parser.get_float("e"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - std::string data_type = arg_parser.get_str("prec"); - int do_validation = arg_parser.get_int("v"); + assert(x_stride >= n); - using XDataType = ck_tile::half_t; - using YDataType = ck_tile::half_t; - using GammaDataType = ck_tile::half_t; - using BetaDataType = ck_tile::half_t; -#ifdef SAVE_MEAN_INV_STD - using MeanDataType = ck_tile::half_t; - using InvStdDataType = ck_tile::half_t; -#else - using MeanDataType = ck_tile::null_type; - using InvStdDataType = ck_tile::null_type; -#endif - using ComputeDataType = float; + using TypeConfig = + LayerNormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using XBiasDataType = typename TypeConfig::XBiasDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using BetaDataType = typename TypeConfig::BetaDataType; + using XResidualDataType = XDataType; + using YResidualDataType = XDataType; + + using MeanDataType = + std::conditional_t; + using InvStdDataType = + std::conditional_t; + + using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify - ck_tile::HostTensor x_host({M, N}); - ck_tile::HostTensor gamma_host({N}); - ck_tile::HostTensor beta_host({N}); + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor x_bias_host({n}); + ck_tile::HostTensor gamma_host({n}); + ck_tile::HostTensor beta_host({n}); - ck_tile::HostTensor y_host_ref({M, N}); - ck_tile::HostTensor y_host_dev({M, N}); + ck_tile::HostTensor x_residual_host({m, n}, {xr_stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {yr_stride, 1}); - ck_tile::HostTensor mean_host_ref({M}); - ck_tile::HostTensor invStd_host_ref({M}); + ck_tile::HostTensor y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {y_stride, 1}); -#ifdef SAVE_MEAN_INV_STD - ck_tile::HostTensor mean_host_dev({M}); - ck_tile::HostTensor invStd_host_dev({M}); -#endif + ck_tile::HostTensor mean_host_ref({m}); + ck_tile::HostTensor invStd_host_ref({m}); + ck_tile::HostTensor y_scale_host_ref({m}); + ck_tile::HostTensor y_scale_host_dev({m}); - ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(gamma_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(beta_host); + ck_tile::HostTensor sm_scale_host({n}); + ck_tile::HostTensor sm_scale_host_dev({n}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_residual_host); + ck_tile::FillUniformDistribution{-1.f, 1.f}(sm_scale_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(x_bias_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(beta_host); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes()); -#ifdef SAVE_MEAN_INV_STD - ck_tile::DeviceMem mean_buf(mean_host_dev.get_element_space_size_in_bytes()); - ck_tile::DeviceMem invStd_buf(invStd_host_dev.get_element_space_size_in_bytes()); -#endif + ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes()); x_buf.ToDevice(x_host.data()); + x_bias_buf.ToDevice(x_bias_host.data()); gamma_buf.ToDevice(gamma_host.data()); beta_buf.ToDevice(beta_host.data()); + x_residual_buf.ToDevice(x_residual_host.data()); + sm_scale_buf.ToDevice(sm_scale_host.data()); - layernorm2d_fwd_traits traits{data_type}; + auto prec_str = [&]() { + auto base_str = prec_i; + if(prec_i != prec_o) + { + base_str += "|" + prec_o; + } + if(fused_quant == 1) + { + base_str += std::string("(") + prec_sy + ")"; + } + return base_str; + }(); + + std::cout << "[" << prec_str << "]" + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride + << ", yr_stride:" << yr_stride << std::flush; + + layernorm2d_fwd_traits traits{ + prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant}; layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr, + x_bias_buf.GetDeviceBuffer(), gamma_buf.GetDeviceBuffer(), beta_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), -#ifdef SAVE_MEAN_INV_STD - mean_buf.GetDeviceBuffer(), - invStd_buf.GetDeviceBuffer(), -#else - nullptr, - nullptr, -#endif + fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr, + fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr, + nullptr, // p_mean, unsupported yet + nullptr, // p_invStd, unsupported yet + epsilon, - M, - N}; + m, + n, + x_stride, // x row_stride + xr_stride, // x residule row stride + y_stride, // y row stride + yr_stride}; // y residule row stride - float ave_time = layernorm2d_fwd(traits, args, ck_tile::stream_config{nullptr, true}); + float ave_time = layernorm2d_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); - std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + - sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n + + sizeof(GammaDataType) * n + sizeof(BetaDataType) * n + + sizeof(YDataType) * m * n; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "[" << data_type << "]" - << " m:" << M << ", n:" << N << ", " << ave_time << " ms, " << gb_per_sec << " GB/s" - << std::flush; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; bool pass = true; if(do_validation) { // reference + if(xbias != 0) + { + // add bias before fadd + int M = x_host.mDesc.get_lengths()[0]; + int N = x_host.mDesc.get_lengths()[1]; + for(int idx_m = 0; idx_m < M; ++idx_m) + { + for(int idx_n = 0; idx_n < N; ++idx_n) + { + x_host(idx_m, idx_n) = ck_tile::type_convert( + ck_tile::type_convert(x_host(idx_m, idx_n)) + + ck_tile::type_convert(x_bias_host(idx_n))); + } + } + } + + if(fused_add != 0) + { + // fused pre_add/pre_add_store + // TODO we accumulate directly to x_host for simplcity here... + + std::transform(x_host.mData.cbegin(), + x_host.mData.cend(), + x_residual_host.mData.cbegin(), + x_host.mData.begin(), + [](auto x_, auto r_) { + auto o_ = ck_tile::type_convert(x_) + + ck_tile::type_convert(r_); + return ck_tile::type_convert(o_); + }); + } ck_tile::reference_layernorm2d_fwd( x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + if(fused_quant != 0) + { + auto dquant_functor = [&](int m_, auto& o_, auto& acc_) { + int N_ = acc_.mDesc.get_lengths()[1]; + if(fused_quant == 1) + { + for(int n_ = 0; n_ < N_; n_++) + { + // input smooth outlier + acc_(m_, n_) = acc_(m_, n_) * + ck_tile::type_convert(sm_scale_host(n_)); + } + } + ComputeDataType absmax = static_cast(0); + for(int n_ = 0; n_ < N_; n_++) + { + const auto a = ck_tile::abs(acc_(m_, n_)); + absmax = a > absmax ? a : absmax; + } + // printf("cpu:absmax:%f\n", absmax); + constexpr ComputeDataType kMaxY = + std::is_same::value ? 240.0 + : std::is_same::value ? 127.0 + : 0.0; + ComputeDataType y_scale = absmax / kMaxY; + y_scale_host_ref(m_) = ck_tile::type_convert(y_scale); + for(int n_ = 0; n_ < N_; n_++) + { + o_(m_, n_) = ck_tile::type_convert(acc_(m_, n_) / y_scale); + } + }; + + ck_tile::reference_layernorm2d_fwd(x_host, + gamma_host, + beta_host, + y_host_ref, + mean_host_ref, + invStd_host_ref, + epsilon, + dquant_functor); + } + else + { + ck_tile::reference_layernorm2d_fwd( + x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon); + } + y_buf.FromDevice(y_host_dev.data()); - pass = ck_tile::check_err(y_host_dev, y_host_ref); + ck_tile::HostTensor y_residual_host_dev({m, n}, {yr_stride, 1}); + if(fused_add == 1) + { + y_residual_buf.FromDevice(y_residual_host_dev.data()); + } -#ifdef SAVE_MEAN_INV_STD - mean_buf.FromDevice(mean_host_dev.data()); - pass &= ck_tile::check_err(mean_host_dev, mean_host_ref); + auto [rtol, atol] = get_elimit(); - invStd_buf.FromDevice(invStd_host_dev.data()); - pass &= ck_tile::check_err(invStd_host_dev, invStd_host_ref); -#endif + if(x_stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + if(fused_add == 1) + { + pass &= ck_tile::check_err(y_residual_host_dev, + x_host, + std::string("ADD Error: Incorrect results!"), + rtol, + atol); + } + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * y_stride, + y_host_dev.begin() + i_r * y_stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * y_stride, + y_host_ref.begin() + i_r * y_stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + if(fused_add == 1) + { + std::vector y_residual_host_dev_row( + y_residual_host_dev.begin() + i_r * yr_stride, + y_residual_host_dev.begin() + i_r * yr_stride + n); + std::vector y_residual_host_ref_row( + x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n); + pass &= ck_tile::check_err(y_residual_host_dev_row, + y_residual_host_ref_row, + std::string("ADD[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + if(fused_quant == 1) + { + y_scale_buf.FromDevice(y_scale_host_dev.data()); + pass &= ck_tile::check_err(y_scale_host_dev, + y_scale_host_ref, + std::string("SCALE Error: Incorrect results!"), + rtol, + atol); + } - std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } - std::cout << std::endl << std::flush; - - return !pass; + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_sm = arg_parser.get_str("prec_sm"); + std::string prec_sy = arg_parser.get_str("prec_sy"); + + if(prec_o == "auto") + { + prec_o = prec_i; + } + if(prec_sm == "auto") + { + prec_sm = "fp32"; + } + if(prec_sy == "auto") + { + prec_sy = "fp32"; + } + int save_mv = arg_parser.get_int("save_mv"); + + // no dynamic quant case + if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && + save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + + // dynamic quant case, only in inference + else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" && + !save_mv) + { + return run(arg_parser) ? 0 : -2; + } + + return -3; } diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp index 4d1aac0994..0538953a58 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,23 +8,63 @@ #include "ck_tile/ops/layernorm2d.hpp" #include +template +struct LayerNormTypeConfig; + +template +struct LayerNormTypeConfig +{ + using XDataType = ck_tile::half_t; + using YDataType = OutType; + using XBiasDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using BetaDataType = ck_tile::half_t; + using MeanDataType = ck_tile::half_t; + using InvStdDataType = ck_tile::half_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; +}; + +template +struct LayerNormTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using YDataType = OutType; + using XBiasDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using BetaDataType = ck_tile::bf16_t; + using MeanDataType = ck_tile::bf16_t; + using InvStdDataType = ck_tile::bf16_t; + using ComputeDataType = float; + using SmoothScaleDataType = SmoothScaleDataType_; + using YScaleDataType = YScaleDataType_; +}; + +// runtime args +struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs +{ +}; + +// This is the public API, will be generated by script struct layernorm2d_fwd_traits { - std::string data_type; + std::string prec_i; // input precision + std::string prec_o; // output precision + + // if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set + // arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise + // can set arbitrary(will skip check) + std::string prec_sm; // x-scale, used for [1*N] input smooth quant + std::string prec_sy; // y-scale, used for [M*1] output for next layer + + bool save_mean_var; // + int xbias; // 0:no-bias, 1:add bias + int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant }; -struct layernorm2d_fwd_args -{ - const void* p_x; - const void* p_gamma; - const void* p_beta; - void* p_y; - void* p_mean; - void* p_invStd; - float epsilon; - ck_tile::index_t M; - ck_tile::index_t N; -}; - -// host API float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/02_layernorm2d/misc/dquant.png b/example/ck_tile/02_layernorm2d/misc/dquant.png new file mode 100644 index 0000000000..28b1a61a14 Binary files /dev/null and b/example/ck_tile/02_layernorm2d/misc/dquant.png differ diff --git a/example/ck_tile/02_layernorm2d/misc/pnorm.png b/example/ck_tile/02_layernorm2d/misc/pnorm.png new file mode 100644 index 0000000000..65a27e8751 Binary files /dev/null and b/example/ck_tile/02_layernorm2d/misc/pnorm.png differ diff --git a/example/ck_tile/02_layernorm2d/script/perf_test.sh b/example/ck_tile/02_layernorm2d/script/perf_test.sh new file mode 100755 index 0000000000..5a34e19280 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/script/perf_test.sh @@ -0,0 +1,37 @@ +#!/bin/sh +EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh new file mode 100755 index 0000000000..ceaf262bd9 --- /dev/null +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -0,0 +1,35 @@ +#!/bin/sh +EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" + +for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=1 -prec_o=fp8"; do +for pr_i in "fp16" "bf16" ; do +for fadd in "0" "1"; do +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=9120 +$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547 +#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 +done +done +done diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 03fc9c7eb1..3d3a54020c 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,2 +1,10 @@ -set(CMAKE_BUILD_TYPE Debug) -add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) \ No newline at end of file +add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) +add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 00303bf62c..da37159aeb 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -6,18 +6,32 @@ This folder contains example for GEMM using ck_tile tile-programming implementat ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j +# The memory bound pipeline on the gemm calculation +make tile_example_gemm_universal -j ``` -This will result in an executable `build/bin/tile_example_gemm_basic` +This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` ## example ``` args: - -m m dimension (default:3328) - -n m dimension (default:4096) + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) -k k dimension (default:64) - -e epsilon (default:1e-5) - -v cpu validation or not (default:1) - -prec precision (default:fp16) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) ``` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 734ba0fe65..80c18cdb87 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,9 +1,7 @@ - // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "gemm_basic.hpp" -#include "ck_tile/host.hpp" +#include #include #include @@ -11,34 +9,37 @@ #include #include -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("b", "1", "batch size") - .insert("m", "1024", "m dimension") - .insert("n", "2048", "n dimension") - .insert("k", "64", "k dimension") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "1", "cpu validation or not") - .insert("e", "1e-5", "Absolute error tolerance") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") - .insert("warmup", "10", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) -template -float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { - // ToDo: This will be modified by the codegen code later. - constexpr ck_tile::index_t M_Tile = 128; - constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 32; + if constexpr(Persistent) + std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; constexpr ck_tile::index_t M_Warp = 2; constexpr ck_tile::index_t N_Warp = 2; @@ -46,229 +47,204 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = 8; + constexpr ck_tile::index_t K_Warp_Tile = 16; - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = false; - - constexpr int kBlockPerCu = 1; - - // =============================================== - - using GemmShape = + using CodegenGemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence>; - using TilePartitioner = ck_tile::GemmTilePartitioner; - using PipelineProblem = ck_tile:: - BlockGemmPipelineProblem; - // The GemmPipeline should also come from the Codegen. - using GemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = - ck_tile::GemmKernel; - auto kargs = Kernel::MakeKargs(args.p_a, - args.p_b, - args.p_c, - args.epsilon, - args.M, - args.N, - args.K, - args.stride_A, - args.stride_B, - args.stride_C); + using TilePartitioner = ck_tile::GemmTile1DPartitioner; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch); - constexpr dim3 blocks = Kernel::BlockSize(); + using CodegenGemmTraits = + ck_tile::TileGemmTraits; - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; - return ave_time; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC, + memory_operation>>; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } } -template -float invoke_gemm(ck_tile::DeviceMem& a_buf, - ck_tile::DeviceMem& b_buf, - ck_tile::DeviceMem& c_buf, - const ck_tile::ArgParser& arg_parser) +#include "run_gemm_example.inc" + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - std::string data_type = arg_parser.get_str("prec"); - - if(data_type != DataTypeTraits::name) + if constexpr(std::is_same_v) { - std::cerr << "Data type mismatch: expected " << DataTypeTraits::name << ", got " - << data_type << std::endl; - return -1; // Or handle the error appropriately - } - - float epsilon = arg_parser.get_float("e"); - ck_tile::index_t batch_size = arg_parser.get_int("b"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - ck_tile::index_t stride_a = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_b = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_c = arg_parser.get_int("stride_c"); - - gemm_basic_args args; - args.p_a = a_buf.GetDeviceBuffer(); - args.p_b = b_buf.GetDeviceBuffer(); - args.p_c = c_buf.GetDeviceBuffer(); - args.epsilon = epsilon; - args.kbatch = batch_size; - args.M = M; - args.N = N; - args.K = K; - - // Only set stride_M and stride_N if they are non-zero and not equal to K. - if(stride_a != 0) - { - args.stride_A = stride_a; + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } } else { - args.stride_A = [&]() { - if constexpr(std::is_same_v) - { - return M; - } - else - { - return K; - } - }(); + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } } - - if(stride_b != 0) - { - args.stride_B = stride_b; - } - else - { - args.stride_B = [&]() { - if constexpr(std::is_same_v) - { - return N; - } - else - { - return K; - } - }(); - } - - if(stride_c != 0) - { - args.stride_C = stride_c; - } - else - { - args.stride_C = [&]() { - if constexpr(std::is_same_v) - { - return M; - } - else - { - return N; - } - }(); - } - - float ave_time = - gemm_calc(args, ck_tile::stream_config{nullptr, true}); - std::size_t num_byte = - sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "The overall perfomance of the GEMM with " - << "[" << data_type << "]" - << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K - << "is: \n"; - std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n" - << std::flush; - - return ave_time; } -int main(int argc, char* argv[]) +int run_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); - // The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N). - using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor; - using matrix_b_layout = ck_tile::tensor_layout::gemm::RowMajor; - using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor; - - // host verify - std::vector a_dimensions = - (std::is_same_v) - ? std::vector{M, K} - : std::vector{K, M}; - std::vector b_dimensions = - (std::is_same_v) - ? std::vector{N, K} - : std::vector{K, N}; - std::vector c_dimensions = - (std::is_same_v) - ? std::vector{M, N} - : std::vector{N, M}; - - ck_tile::HostTensor a_host(a_dimensions); - ck_tile::HostTensor b_host(b_dimensions); - - ck_tile::HostTensor c_host_ref(c_dimensions); - ck_tile::HostTensor c_host_dev(c_dimensions); - - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); - - ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); - - a_buf.ToDevice(a_host.data()); - b_buf.ToDevice(b_host.data()); - - invoke_gemm( - a_buf, b_buf, c_buf, arg_parser); - - bool pass = true; - - if(arg_parser.get_bool("v")) + if(data_type == "fp16") { - // ToDo: Will Add the Element Op (bias) verification in the future. - ck_tile::reference_gemm(a_host, b_host, c_host_ref); - - c_buf.FromDevice(c_host_dev.data()); - - pass = ck_tile::check_err(c_host_dev, c_host_ref); - - std::cout << "The veification result is:" << (pass ? "correct" : "fail") << std::flush; + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "fp8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "i8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int main(int argc, char* argv[]) +{ + try + { + return !run_gemm_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; } - - std::cout << std::endl << std::flush; - - return !pass; } diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp deleted file mode 100644 index 28afb194c9..0000000000 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ /dev/null @@ -1,71 +0,0 @@ - -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include - -template -struct GemmBasicTypeConfig; - -template <> -struct GemmBasicTypeConfig -{ - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using AccDataType = float; - using CDataType = ck_tile::half_t; // type convert - // ToDo: Add more bias config to support different categories of GEMM. -}; - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -using Types = GemmBasicTypeConfig; - -// Specific type aliases for easy access -using ADataType = Types::ADataType; -using BDataType = Types::BDataType; -using AccDataType = Types::AccDataType; -using CDataType = Types::CDataType; - -struct gemm_basic_args -{ - const void* p_a; - const void* p_b; - void* p_c; - float epsilon; - ck_tile::index_t kbatch; - ck_tile::index_t M; - ck_tile::index_t N; - ck_tile::index_t K; - ck_tile::index_t stride_A; - ck_tile::index_t stride_B; - ck_tile::index_t stride_C; -}; - -// host API -float gemm_calc(gemm_basic_args args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp new file mode 100644 index 0000000000..9deccc7f16 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; +}; + +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp new file mode 100644 index 0000000000..f57c24f458 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && a_layout != "R" && b_layout != "C") + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } +} + +template