mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix and merge
This commit is contained in:
12
.github/CODEOWNERS
vendored
12
.github/CODEOWNERS
vendored
@@ -1,8 +1,8 @@
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
# Documentation files
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
# Header directory for Doxygen documentation
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -68,3 +68,6 @@ build*/
|
||||
|
||||
# Python cache
|
||||
__pycache__/
|
||||
|
||||
.cache/
|
||||
|
||||
|
||||
6
.pre-commit-config.yaml
Executable file → Normal file
6
.pre-commit-config.yaml
Executable file → Normal file
@@ -12,3 +12,9 @@ repos:
|
||||
verbose: false
|
||||
language: script
|
||||
types: [c++]
|
||||
- id: remove-exec-bit
|
||||
name: Remove executable bit from non-executable files
|
||||
entry: script/remove_exec_bit.sh
|
||||
language: script
|
||||
types_or: [c++, text]
|
||||
verbose: true
|
||||
|
||||
@@ -13,10 +13,16 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
|
||||
* Added support for Stream-K version of mixed fp8/bf16 GEMM
|
||||
* Added GEMM pipeline for microscaling (MX) data types
|
||||
* Added support for Multiple D GEMM
|
||||
* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv)
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
372
CMakeLists.txt
372
CMakeLists.txt
@@ -26,17 +26,21 @@ set(version 1.1.0)
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
|
||||
include(CTest)
|
||||
|
||||
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
|
||||
# CK Codegen requires dataclass which is added in Python 3.7
|
||||
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
|
||||
if(NOT CK_USE_ALTERNATIVE_PYTHON)
|
||||
find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED)
|
||||
else()
|
||||
message("Using alternative python version")
|
||||
message(STATUS "Using alternative python version")
|
||||
set(EXTRA_PYTHON_PATH)
|
||||
# this is overly restrictive, we may need to be more flexible on the following
|
||||
string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
|
||||
message("alternative python path is: ${EXTRA_PYTHON_PATH}")
|
||||
message(STATUS "alternative python path is: ${EXTRA_PYTHON_PATH}")
|
||||
find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
|
||||
add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
|
||||
set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
|
||||
@@ -76,7 +80,7 @@ if (DTYPES)
|
||||
add_definitions(-DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
message(STATUS "DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
@@ -94,6 +98,9 @@ add_compile_options(-Wno-pass-failed)
|
||||
add_compile_options(-Wno-switch-default)
|
||||
add_compile_options(-Wno-unique-object-duplication)
|
||||
|
||||
# Recent change in compiler makes this warning ON by default, which led to compile errors.
|
||||
add_compile_options(-Wno-nrvo)
|
||||
|
||||
if(NOT DISABLE_DL_KERNELS)
|
||||
add_definitions(-DDL_KERNELS)
|
||||
set(DL_KERNELS "ON")
|
||||
@@ -139,8 +146,8 @@ rocm_setup_version(VERSION ${version})
|
||||
|
||||
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}")
|
||||
|
||||
message("GPU_TARGETS= ${GPU_TARGETS}")
|
||||
message("GPU_ARCHS= ${GPU_ARCHS}")
|
||||
message(STATUS "GPU_TARGETS= ${GPU_TARGETS}")
|
||||
message(STATUS "GPU_ARCHS= ${GPU_ARCHS}")
|
||||
if(GPU_ARCHS)
|
||||
#disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package
|
||||
unset(GPU_TARGETS CACHE)
|
||||
@@ -155,9 +162,9 @@ find_package(hip REQUIRED)
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
# SWDEV-413293 and https://reviews.llvm.org/D155213
|
||||
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
|
||||
message("hip_version_flat=${hip_VERSION_FLAT}")
|
||||
message(STATUS "hip_version_flat=${hip_VERSION_FLAT}")
|
||||
|
||||
message("checking which targets are supported")
|
||||
message(STATUS "checking which targets are supported")
|
||||
#In order to build just the CK library (without tests and examples) for all supported GPU targets
|
||||
#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
|
||||
@@ -169,8 +176,10 @@ if(NOT ENABLE_ASAN_PACKAGING)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600300000 AND ${hip_VERSION_FLAT} LESS 600400000)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000)
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000 AND ${hip_VERSION_FLAT} LESS 600443483)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950")
|
||||
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx950;gfx10-3-generic;gfx11-generic;gfx12-generic")
|
||||
endif()
|
||||
else()
|
||||
#build CK only for xnack-supported targets when using ASAN
|
||||
@@ -194,25 +203,25 @@ endif()
|
||||
rocm_check_target_ids(SUPPORTED_GPU_TARGETS
|
||||
TARGETS ${CK_GPU_TARGETS})
|
||||
|
||||
message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
|
||||
message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
message("Enabling XDL instances")
|
||||
message(STATUS "Enabling XDL instances")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
message("Enabling XDL FP8 gemms on native architectures")
|
||||
message(STATUS "Enabling XDL FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message("Enabling WMMA instances")
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message("Enabling WMMA FP8 gemms on native architectures")
|
||||
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_WMMA_FP8)
|
||||
set(CK_USE_WMMA_FP8 "ON")
|
||||
endif()
|
||||
@@ -241,32 +250,32 @@ configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/con
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
|
||||
check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK)
|
||||
if(HAS_NO_OFFLOAD_UNIFORM_BLOCK)
|
||||
message("Adding the fno-offload-uniform-block compiler flag")
|
||||
message(STATUS "Adding the fno-offload-uniform-block compiler flag")
|
||||
add_compile_options(-fno-offload-uniform-block)
|
||||
endif()
|
||||
endif()
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000)
|
||||
check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION)
|
||||
if(HAS_LSR_DROP_SOLUTION)
|
||||
message("Adding the lsr-drop-solution=1 compiler flag")
|
||||
message(STATUS "Adding the lsr-drop-solution=1 compiler flag")
|
||||
add_compile_options("SHELL: -mllvm --lsr-drop-solution=1")
|
||||
endif()
|
||||
endif()
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
|
||||
check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
|
||||
if(HAS_ENABLE_POST_MISCHED)
|
||||
message("Adding the enable-post-misched=0 compiler flag")
|
||||
message(STATUS "Adding the enable-post-misched=0 compiler flag")
|
||||
add_compile_options("SHELL: -mllvm -enable-post-misched=0")
|
||||
endif()
|
||||
endif()
|
||||
set(check-coerce)
|
||||
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
|
||||
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
message("Adding the amdgpu-coerce-illegal-types=1")
|
||||
message(STATUS "Adding the amdgpu-coerce-illegal-types=1")
|
||||
add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1")
|
||||
endif()
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
message("Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false")
|
||||
message(STATUS "Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false")
|
||||
add_compile_options("SHELL: -mllvm -amdgpu-early-inline-all=true")
|
||||
add_compile_options("SHELL: -mllvm -amdgpu-function-calls=false")
|
||||
endif()
|
||||
@@ -299,17 +308,24 @@ endif()
|
||||
|
||||
option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
|
||||
option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF)
|
||||
option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
add_compile_options(-Wno-bit-int-extension)
|
||||
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_GFX11)
|
||||
add_compile_options(-mcumode)
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
|
||||
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
|
||||
endif()
|
||||
|
||||
if(ENABLE_ASM_DUMP)
|
||||
add_compile_options(--save-temps)
|
||||
add_compile_options(-Wno-gnu-line-marker)
|
||||
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
|
||||
endif()
|
||||
|
||||
## Threads
|
||||
@@ -321,7 +337,7 @@ link_libraries(Threads::Threads)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
|
||||
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
|
||||
|
||||
# https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_macros.html
|
||||
# _GLIBCXX_ASSERTIONS
|
||||
@@ -337,7 +353,7 @@ endif()
|
||||
set(CMAKE_HIP_PLATFORM amd)
|
||||
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
|
||||
set(CMAKE_HIP_EXTENSIONS ON)
|
||||
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
|
||||
message(STATUS "CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
|
||||
|
||||
## OpenMP
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
@@ -352,10 +368,10 @@ else()
|
||||
find_package(OpenMP REQUIRED)
|
||||
endif()
|
||||
|
||||
message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
|
||||
message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
|
||||
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
|
||||
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
|
||||
message(STATUS "OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
|
||||
message(STATUS "OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
|
||||
message(STATUS "OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
|
||||
message(STATUS "OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
|
||||
|
||||
link_libraries(${OpenMP_gomp_LIBRARY})
|
||||
link_libraries(${OpenMP_pthread_LIBRARY})
|
||||
@@ -387,146 +403,152 @@ else()
|
||||
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
|
||||
endif()
|
||||
|
||||
## tidy
|
||||
include(EnableCompilerWarnings)
|
||||
## tidy
|
||||
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
|
||||
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
|
||||
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
|
||||
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
|
||||
# Enable tidy on hip
|
||||
elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
|
||||
set(CK_TIDY_ERRORS ALL)
|
||||
set(CK_TIDY_ERRORS ALL)
|
||||
endif()
|
||||
|
||||
|
||||
include(ClangTidy)
|
||||
enable_clang_tidy(
|
||||
CHECKS
|
||||
*
|
||||
-abseil-*
|
||||
-android-cloexec-fopen
|
||||
# Yea we shouldn't be using rand()
|
||||
-cert-msc30-c
|
||||
-bugprone-exception-escape
|
||||
-bugprone-macro-parentheses
|
||||
-cert-env33-c
|
||||
-cert-msc32-c
|
||||
-cert-msc50-cpp
|
||||
-cert-msc51-cpp
|
||||
-cert-dcl37-c
|
||||
-cert-dcl51-cpp
|
||||
-clang-analyzer-alpha.core.CastToStruct
|
||||
-clang-analyzer-optin.performance.Padding
|
||||
-clang-diagnostic-deprecated-declarations
|
||||
-clang-diagnostic-extern-c-compat
|
||||
-clang-diagnostic-unused-command-line-argument
|
||||
-cppcoreguidelines-avoid-c-arrays
|
||||
-cppcoreguidelines-avoid-magic-numbers
|
||||
-cppcoreguidelines-explicit-virtual-functions
|
||||
-cppcoreguidelines-init-variables
|
||||
-cppcoreguidelines-macro-usage
|
||||
-cppcoreguidelines-non-private-member-variables-in-classes
|
||||
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
|
||||
-cppcoreguidelines-pro-bounds-constant-array-index
|
||||
-cppcoreguidelines-pro-bounds-pointer-arithmetic
|
||||
-cppcoreguidelines-pro-type-member-init
|
||||
-cppcoreguidelines-pro-type-reinterpret-cast
|
||||
-cppcoreguidelines-pro-type-union-access
|
||||
-cppcoreguidelines-pro-type-vararg
|
||||
-cppcoreguidelines-special-member-functions
|
||||
-fuchsia-*
|
||||
-google-explicit-constructor
|
||||
-google-readability-braces-around-statements
|
||||
-google-readability-todo
|
||||
-google-runtime-int
|
||||
-google-runtime-references
|
||||
-hicpp-vararg
|
||||
-hicpp-braces-around-statements
|
||||
-hicpp-explicit-conversions
|
||||
-hicpp-named-parameter
|
||||
-hicpp-no-array-decay
|
||||
# We really shouldn't use bitwise operators with signed integers, but
|
||||
# opencl leaves us no choice
|
||||
-hicpp-avoid-c-arrays
|
||||
-hicpp-signed-bitwise
|
||||
-hicpp-special-member-functions
|
||||
-hicpp-uppercase-literal-suffix
|
||||
-hicpp-use-auto
|
||||
-hicpp-use-equals-default
|
||||
-hicpp-use-override
|
||||
-llvm-header-guard
|
||||
-llvm-include-order
|
||||
#-llvmlibc-*
|
||||
-llvmlibc-restrict-system-libc-headers
|
||||
-llvmlibc-callee-namespace
|
||||
-llvmlibc-implementation-in-namespace
|
||||
-llvm-else-after-return
|
||||
-llvm-qualified-auto
|
||||
-misc-misplaced-const
|
||||
-misc-non-private-member-variables-in-classes
|
||||
-misc-no-recursion
|
||||
-modernize-avoid-bind
|
||||
-modernize-avoid-c-arrays
|
||||
-modernize-pass-by-value
|
||||
-modernize-use-auto
|
||||
-modernize-use-default-member-init
|
||||
-modernize-use-equals-default
|
||||
-modernize-use-trailing-return-type
|
||||
-modernize-use-transparent-functors
|
||||
-performance-unnecessary-value-param
|
||||
-readability-braces-around-statements
|
||||
-readability-else-after-return
|
||||
# we are not ready to use it, but very useful
|
||||
-readability-function-cognitive-complexity
|
||||
-readability-isolate-declaration
|
||||
-readability-magic-numbers
|
||||
-readability-named-parameter
|
||||
-readability-uppercase-literal-suffix
|
||||
-readability-convert-member-functions-to-static
|
||||
-readability-qualified-auto
|
||||
-readability-redundant-string-init
|
||||
# too many narrowing conversions in our code
|
||||
-bugprone-narrowing-conversions
|
||||
-cppcoreguidelines-narrowing-conversions
|
||||
-altera-struct-pack-align
|
||||
-cppcoreguidelines-prefer-member-initializer
|
||||
${CK_TIDY_CHECKS}
|
||||
${CK_TIDY_ERRORS}
|
||||
HEADER_FILTER
|
||||
"\.hpp$"
|
||||
EXTRA_ARGS
|
||||
-DCK_USE_CLANG_TIDY
|
||||
)
|
||||
if(ENABLE_CLANG_CPP_CHECKS)
|
||||
include(ClangTidy)
|
||||
enable_clang_tidy(
|
||||
CHECKS
|
||||
*
|
||||
-abseil-*
|
||||
-android-cloexec-fopen
|
||||
# Yea we shouldn't be using rand()
|
||||
-cert-msc30-c
|
||||
-bugprone-exception-escape
|
||||
-bugprone-macro-parentheses
|
||||
-cert-env33-c
|
||||
-cert-msc32-c
|
||||
-cert-msc50-cpp
|
||||
-cert-msc51-cpp
|
||||
-cert-dcl37-c
|
||||
-cert-dcl51-cpp
|
||||
-clang-analyzer-alpha.core.CastToStruct
|
||||
-clang-analyzer-optin.performance.Padding
|
||||
-clang-diagnostic-deprecated-declarations
|
||||
-clang-diagnostic-extern-c-compat
|
||||
-clang-diagnostic-unused-command-line-argument
|
||||
-cppcoreguidelines-avoid-c-arrays
|
||||
-cppcoreguidelines-avoid-magic-numbers
|
||||
-cppcoreguidelines-explicit-virtual-functions
|
||||
-cppcoreguidelines-init-variables
|
||||
-cppcoreguidelines-macro-usage
|
||||
-cppcoreguidelines-non-private-member-variables-in-classes
|
||||
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
|
||||
-cppcoreguidelines-pro-bounds-constant-array-index
|
||||
-cppcoreguidelines-pro-bounds-pointer-arithmetic
|
||||
-cppcoreguidelines-pro-type-member-init
|
||||
-cppcoreguidelines-pro-type-reinterpret-cast
|
||||
-cppcoreguidelines-pro-type-union-access
|
||||
-cppcoreguidelines-pro-type-vararg
|
||||
-cppcoreguidelines-special-member-functions
|
||||
-fuchsia-*
|
||||
-google-explicit-constructor
|
||||
-google-readability-braces-around-statements
|
||||
-google-readability-todo
|
||||
-google-runtime-int
|
||||
-google-runtime-references
|
||||
-hicpp-vararg
|
||||
-hicpp-braces-around-statements
|
||||
-hicpp-explicit-conversions
|
||||
-hicpp-named-parameter
|
||||
-hicpp-no-array-decay
|
||||
# We really shouldn't use bitwise operators with signed integers, but
|
||||
# opencl leaves us no choice
|
||||
-hicpp-avoid-c-arrays
|
||||
-hicpp-signed-bitwise
|
||||
-hicpp-special-member-functions
|
||||
-hicpp-uppercase-literal-suffix
|
||||
-hicpp-use-auto
|
||||
-hicpp-use-equals-default
|
||||
-hicpp-use-override
|
||||
-llvm-header-guard
|
||||
-llvm-include-order
|
||||
#-llvmlibc-*
|
||||
-llvmlibc-restrict-system-libc-headers
|
||||
-llvmlibc-callee-namespace
|
||||
-llvmlibc-implementation-in-namespace
|
||||
-llvm-else-after-return
|
||||
-llvm-qualified-auto
|
||||
-misc-misplaced-const
|
||||
-misc-non-private-member-variables-in-classes
|
||||
-misc-no-recursion
|
||||
-modernize-avoid-bind
|
||||
-modernize-avoid-c-arrays
|
||||
-modernize-pass-by-value
|
||||
-modernize-use-auto
|
||||
-modernize-use-default-member-init
|
||||
-modernize-use-equals-default
|
||||
-modernize-use-trailing-return-type
|
||||
-modernize-use-transparent-functors
|
||||
-performance-unnecessary-value-param
|
||||
-readability-braces-around-statements
|
||||
-readability-else-after-return
|
||||
# we are not ready to use it, but very useful
|
||||
-readability-function-cognitive-complexity
|
||||
-readability-isolate-declaration
|
||||
-readability-magic-numbers
|
||||
-readability-named-parameter
|
||||
-readability-uppercase-literal-suffix
|
||||
-readability-convert-member-functions-to-static
|
||||
-readability-qualified-auto
|
||||
-readability-redundant-string-init
|
||||
# too many narrowing conversions in our code
|
||||
-bugprone-narrowing-conversions
|
||||
-cppcoreguidelines-narrowing-conversions
|
||||
-altera-struct-pack-align
|
||||
-cppcoreguidelines-prefer-member-initializer
|
||||
${CK_TIDY_CHECKS}
|
||||
${CK_TIDY_ERRORS}
|
||||
HEADER_FILTER
|
||||
"\.hpp$"
|
||||
EXTRA_ARGS
|
||||
-DCK_USE_CLANG_TIDY
|
||||
)
|
||||
|
||||
include(CppCheck)
|
||||
enable_cppcheck(
|
||||
CHECKS
|
||||
warning
|
||||
style
|
||||
performance
|
||||
portability
|
||||
SUPPRESS
|
||||
ConfigurationNotChecked
|
||||
constStatement
|
||||
duplicateCondition
|
||||
noExplicitConstructor
|
||||
passedByValue
|
||||
preprocessorErrorDirective
|
||||
shadowVariable
|
||||
unusedFunction
|
||||
unusedPrivateFunction
|
||||
unusedStructMember
|
||||
unmatchedSuppression
|
||||
FORCE
|
||||
SOURCES
|
||||
library/src
|
||||
INCLUDE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/library/include
|
||||
DEFINE
|
||||
CPPCHECK=1
|
||||
__linux__=1
|
||||
)
|
||||
include(CppCheck)
|
||||
enable_cppcheck(
|
||||
CHECKS
|
||||
warning
|
||||
style
|
||||
performance
|
||||
portability
|
||||
SUPPRESS
|
||||
ConfigurationNotChecked
|
||||
constStatement
|
||||
duplicateCondition
|
||||
noExplicitConstructor
|
||||
passedByValue
|
||||
preprocessorErrorDirective
|
||||
shadowVariable
|
||||
unusedFunction
|
||||
unusedPrivateFunction
|
||||
unusedStructMember
|
||||
unmatchedSuppression
|
||||
FORCE
|
||||
SOURCES
|
||||
library/src
|
||||
INCLUDE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/library/include
|
||||
DEFINE
|
||||
CPPCHECK=1
|
||||
__linux__=1
|
||||
)
|
||||
else()
|
||||
function(clang_tidy_check TARGET)
|
||||
# stub out empty function if clang tidy is not enabled
|
||||
endfunction()
|
||||
endif()
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
@@ -545,7 +567,7 @@ if(BUILD_DEV)
|
||||
add_compile_options(-Werror)
|
||||
add_compile_options(-Weverything)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
|
||||
add_compile_options(-fcolor-diagnostics)
|
||||
@@ -554,12 +576,15 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
|
||||
add_compile_options(-fdiagnostics-color=always)
|
||||
endif()
|
||||
|
||||
# make check runs the entire set of examples and tests
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
|
||||
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
|
||||
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
|
||||
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
|
||||
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
|
||||
if(NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
# make check runs the entire set of examples and tests
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
|
||||
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
|
||||
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
|
||||
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
|
||||
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
|
||||
@@ -602,9 +627,14 @@ ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
@@ -621,11 +651,13 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
if (NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
endif()
|
||||
|
||||
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
add_subdirectory(codegen)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
FROM ubuntu:24.04
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=6.4
|
||||
ARG ROCMVERSION=6.4.1
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
@@ -13,8 +13,8 @@ RUN set -xe && \
|
||||
curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
|
||||
|
||||
RUN if [ "$ROCMVERSION" != "6.5" ]; then \
|
||||
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60400-1_all.deb --no-check-certificate" && \
|
||||
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60400-1_all.deb && \
|
||||
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \
|
||||
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \
|
||||
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
|
||||
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \
|
||||
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4"
|
||||
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1"
|
||||
FROM $BASE_DOCKER
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
|
||||
382
Jenkinsfile
vendored
382
Jenkinsfile
vendored
@@ -12,6 +12,23 @@ def show_node_info() {
|
||||
"""
|
||||
}
|
||||
|
||||
class Version {
|
||||
int major, minor, patch
|
||||
@Override
|
||||
String toString() {
|
||||
return [major, minor, patch].findAll().join('.')
|
||||
}
|
||||
}
|
||||
def parseVersion(String versionString) {
|
||||
if (!versionString) return null
|
||||
int[] tokens = versionString.split(/\./).collect { it as int } // Splits the string by '.' and converts each part to an integer.
|
||||
return new Version(
|
||||
major: tokens[0],
|
||||
minor: tokens.length > 1 ? tokens[1] : null,
|
||||
patch: tokens.length > 2 ? tokens[2] : null,
|
||||
)
|
||||
}
|
||||
|
||||
def nthreads() {
|
||||
def nproc = sh(returnStdout: true, script: 'nproc')
|
||||
echo "Number of cores: ${nproc}"
|
||||
@@ -38,8 +55,8 @@ def getBaseDockerImageName(){
|
||||
img = "${params.USE_CUSTOM_DOCKER}"
|
||||
}
|
||||
else{
|
||||
def ROCM_numeric = "${params.ROCMVERSION}" as float
|
||||
if ( ROCM_numeric < 6.5 ){
|
||||
def ROCM_numeric = parseVersion("${params.ROCMVERSION}")
|
||||
if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}"
|
||||
}
|
||||
else{
|
||||
@@ -93,6 +110,33 @@ def build_compiler(){
|
||||
return compiler
|
||||
}
|
||||
|
||||
def check_arch(){
|
||||
def arch_type = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_type = 1
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_type = 2
|
||||
}
|
||||
else if ( runShell('grep -n "gfx10" rocminfo.log') ) {
|
||||
arch_type = 3
|
||||
}
|
||||
else if ( runShell('grep -n "gfx11" rocminfo.log') ) {
|
||||
arch_type = 4
|
||||
}
|
||||
else if ( runShell('grep -n "gfx12" rocminfo.log') ) {
|
||||
arch_type = 5
|
||||
}
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_type = 6
|
||||
}
|
||||
else if ( runShell('grep -n "gfx950" rocminfo.log') ) {
|
||||
arch_type = 7
|
||||
}
|
||||
return arch_type
|
||||
}
|
||||
|
||||
def getDockerImage(Map conf=[:]){
|
||||
env.DOCKER_BUILDKIT=1
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
@@ -108,6 +152,10 @@ def getDockerImage(Map conf=[:]){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using special docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
@@ -177,13 +225,20 @@ def cmake_build(Map conf=[:]){
|
||||
def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","")
|
||||
def prefixpath = conf.get("prefixpath","/opt/rocm")
|
||||
def setup_args = conf.get("setup_args","")
|
||||
|
||||
// make sure all unit tests always run on develop branch
|
||||
def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS
|
||||
|
||||
if (prefixpath != "/usr/local"){
|
||||
setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} "
|
||||
}
|
||||
|
||||
def build_type_debug = (conf.get("build_type",'release') == 'debug')
|
||||
|
||||
// use special compiler for gfx950
|
||||
if ( check_arch() == 7){
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
}
|
||||
|
||||
//cmake_env can overwrite default CXX variables.
|
||||
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
|
||||
|
||||
@@ -239,6 +294,9 @@ def cmake_build(Map conf=[:]){
|
||||
if (setup_args.contains("gfx94")){
|
||||
invocation_tag="gfx94"
|
||||
}
|
||||
if (setup_args.contains("gfx95")){
|
||||
invocation_tag="gfx95"
|
||||
}
|
||||
echo "invocation tag: ${invocation_tag}"
|
||||
def redis_pre_setup_cmd = pre_setup_cmd
|
||||
if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") {
|
||||
@@ -287,15 +345,8 @@ def cmake_build(Map conf=[:]){
|
||||
def build_cmd
|
||||
def execute_cmd = conf.get("execute_cmd", "")
|
||||
if(!setup_args.contains("NO_CK_BUILD")){
|
||||
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
|
||||
echo "running ninja build trace"
|
||||
setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """)
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
|
||||
}
|
||||
else{
|
||||
setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}")
|
||||
}
|
||||
setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """)
|
||||
build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
|
||||
cmd = conf.get("cmd", """
|
||||
${setup_cmd}
|
||||
${build_cmd}
|
||||
@@ -315,7 +366,7 @@ def cmake_build(Map conf=[:]){
|
||||
sh cmd
|
||||
//run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set
|
||||
if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){
|
||||
if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
|
||||
if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){
|
||||
sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log"
|
||||
sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log"
|
||||
@@ -323,13 +374,31 @@ def cmake_build(Map conf=[:]){
|
||||
archiveArtifacts "clang_build_analysis.log"
|
||||
// do not run unit tests when building instances only
|
||||
if(!params.BUILD_INSTANCES_ONLY){
|
||||
sh "ninja test"
|
||||
if (!runAllUnitTests){
|
||||
sh "../script/launch_tests.sh"
|
||||
}
|
||||
else{
|
||||
sh "ninja check"
|
||||
}
|
||||
}
|
||||
if(params.BUILD_INSTANCES_ONLY){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'ninja -j64 package'
|
||||
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb'
|
||||
stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages"
|
||||
}
|
||||
}
|
||||
else{
|
||||
// run unit tests unless building library for all targets
|
||||
if (!params.BUILD_INSTANCES_ONLY){
|
||||
sh "make check"
|
||||
if (!runAllUnitTests){
|
||||
sh "../script/launch_tests.sh"
|
||||
}
|
||||
else{
|
||||
sh "ninja check"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -340,21 +409,14 @@ def cmake_build(Map conf=[:]){
|
||||
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
|
||||
}
|
||||
//check the node gpu architecture
|
||||
def arch_type = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_type = 1
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_type = 2
|
||||
}
|
||||
def arch = check_arch()
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_fmha_*.log"
|
||||
if (arch_type == 1){
|
||||
if (arch == 1){
|
||||
stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a"
|
||||
}
|
||||
else if (arch_type == 2){
|
||||
else if (arch == 2){
|
||||
stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942"
|
||||
}
|
||||
}
|
||||
@@ -379,10 +441,10 @@ def cmake_build(Map conf=[:]){
|
||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_tile_gemm_**.log"
|
||||
if (arch_type == 1){
|
||||
if (arch == 1){
|
||||
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
|
||||
}
|
||||
else if (arch_type == 2){
|
||||
else if (arch == 2){
|
||||
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
|
||||
}
|
||||
}
|
||||
@@ -397,20 +459,16 @@ def buildHipClangJob(Map conf=[:]){
|
||||
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts
|
||||
if ( params.BUILD_INSTANCES_ONLY ){
|
||||
dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
}
|
||||
else{
|
||||
dockerOpts = "--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
}
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -424,7 +482,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
|
||||
def image
|
||||
def retimage
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
|
||||
@@ -465,17 +523,6 @@ def Build_CK(Map conf=[:]){
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
env.DOCKER_BUILDKIT=1
|
||||
checkout scm
|
||||
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
@@ -496,6 +543,7 @@ def Build_CK(Map conf=[:]){
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
def image
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
@@ -521,28 +569,9 @@ def Build_CK(Map conf=[:]){
|
||||
timeout(time: 20, unit: 'HOURS')
|
||||
{
|
||||
//check whether to run performance tests on this node
|
||||
def arch_type = 0
|
||||
sh 'rocminfo | tee rocminfo.log'
|
||||
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
|
||||
arch_type = 1
|
||||
}
|
||||
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
|
||||
arch_type = 2
|
||||
}
|
||||
else if ( runShell('grep -n "gfx10" rocminfo.log') ) {
|
||||
arch_type = 3
|
||||
}
|
||||
else if ( runShell('grep -n "gfx11" rocminfo.log') ) {
|
||||
arch_type = 4
|
||||
}
|
||||
else if ( runShell('grep -n "gfx12" rocminfo.log') ) {
|
||||
arch_type = 5
|
||||
}
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_type = 6
|
||||
}
|
||||
def arch = check_arch()
|
||||
cmake_build(conf)
|
||||
if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){
|
||||
if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch == 1 ){
|
||||
echo "Run inductor codegen tests"
|
||||
sh """
|
||||
python3 -m venv ${env.WORKSPACE}
|
||||
@@ -553,9 +582,9 @@ def Build_CK(Map conf=[:]){
|
||||
"""
|
||||
}
|
||||
dir("build"){
|
||||
if (params.RUN_FULL_QA && arch_type == 2 ){
|
||||
// build deb packages for all gfx9 targets on gfx90a system and prepare to export
|
||||
echo "Build ckProfiler package"
|
||||
if (params.RUN_FULL_QA && arch == 2 ){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'make -j package'
|
||||
archiveArtifacts artifacts: 'composablekernel*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
|
||||
@@ -568,7 +597,7 @@ def Build_CK(Map conf=[:]){
|
||||
// run performance tests, stash the logs, results will be processed on the master node
|
||||
dir("script"){
|
||||
if (params.RUN_PERFORMANCE_TESTS){
|
||||
if (params.RUN_FULL_QA && arch_type == 1){
|
||||
if (params.RUN_FULL_QA && arch == 1){
|
||||
// run full tests on gfx90a
|
||||
echo "Run full performance tests"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
@@ -587,7 +616,7 @@ def Build_CK(Map conf=[:]){
|
||||
archiveArtifacts "perf_mixed_gemm.log"
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
}
|
||||
else if ( arch_type == 1 ){
|
||||
else if ( arch == 1 ){
|
||||
// run standard tests on gfx90a
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
@@ -598,37 +627,44 @@ def Build_CK(Map conf=[:]){
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
}
|
||||
// disable performance tests on gfx1030 for now.
|
||||
//else if ( arch_type == 3){
|
||||
//else if ( arch == 3){
|
||||
// run basic tests on gfx1030
|
||||
// echo "Run gemm performance tests"
|
||||
// sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10"
|
||||
// archiveArtifacts "perf_onnx_gemm_gfx10.log"
|
||||
// stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10"
|
||||
//}
|
||||
else if ( arch_type == 4){
|
||||
else if ( arch == 4){
|
||||
// run basic tests on gfx11
|
||||
echo "Run gemm performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx11.log"
|
||||
stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11"
|
||||
}
|
||||
else if ( arch_type == 5 ){
|
||||
else if ( arch == 5 ){
|
||||
// run basic tests on gfx12
|
||||
echo "Run gemm performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx12.log"
|
||||
stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12"
|
||||
}
|
||||
else if ( arch_type == 6 ){
|
||||
else if ( arch == 6 ){
|
||||
// run basic tests on gfx908
|
||||
echo "Run performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx908.log"
|
||||
stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908"
|
||||
}
|
||||
else if ( arch == 7 ){
|
||||
// run basic tests on gfx950
|
||||
echo "Run performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx950.log"
|
||||
stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950"
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.hipTensor_test && arch_type == 1 ){
|
||||
if (params.hipTensor_test && arch == 1 ){
|
||||
// build and test hipTensor on gfx90a node
|
||||
sh """#!/bin/bash
|
||||
rm -rf "${params.hipTensor_branch}".zip
|
||||
@@ -730,24 +766,10 @@ def process_results(Map conf=[:]){
|
||||
echo "could not locate the GEMM performance logs: ${err.getMessage()}."
|
||||
}
|
||||
}
|
||||
if (params.RUN_FULL_QA){
|
||||
// unstash perf files to master
|
||||
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
unstash "packages"
|
||||
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
|
||||
try{
|
||||
unstash "perf_log"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate perf_log: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx11"
|
||||
unstash "perf_log_gfx12"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
sh "./process_qa_data.sh"
|
||||
}
|
||||
else{
|
||||
// unstash perf files to master
|
||||
@@ -775,12 +797,12 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
|
||||
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
|
||||
0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
|
||||
|
||||
pipeline {
|
||||
@@ -802,8 +824,8 @@ pipeline {
|
||||
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
|
||||
string(
|
||||
name: 'ROCMVERSION',
|
||||
defaultValue: '6.4',
|
||||
description: 'Specify which ROCM version to use: 6.3 (default).')
|
||||
defaultValue: '6.4.1',
|
||||
description: 'Specify which ROCM version to use: 6.4.1 (default).')
|
||||
string(
|
||||
name: 'COMPILER_VERSION',
|
||||
defaultValue: '',
|
||||
@@ -842,16 +864,16 @@ pipeline {
|
||||
description: "Run the cppcheck static analysis (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_PERFORMANCE_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run the performance tests (default: ON)")
|
||||
defaultValue: false,
|
||||
description: "Run the performance tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the grouped conv large cases tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CODEGEN_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run codegen tests (default: OFF)")
|
||||
defaultValue: true,
|
||||
description: "Run codegen tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_FMHA_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -864,6 +886,10 @@ pipeline {
|
||||
name: "RUN_CK_TILE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile GEMM tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_TILE_ENGINE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the tile_engine_gemm tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
@@ -872,6 +898,26 @@ pipeline {
|
||||
name: "BUILD_GFX908",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx908 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX90A",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx90a (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX942",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx942 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX950",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx950 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX10",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx10 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX11",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx11 (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX12",
|
||||
defaultValue: true,
|
||||
@@ -888,6 +934,10 @@ pipeline {
|
||||
name: "RUN_INDUCTOR_TESTS",
|
||||
defaultValue: true,
|
||||
description: "Run inductor codegen tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_ALL_UNIT_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run all unit tests (default: OFF)")
|
||||
}
|
||||
environment{
|
||||
dbuser = "${dbuser}"
|
||||
@@ -1000,7 +1050,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CODEGEN_TESTS.toBoolean() }
|
||||
expression { params.RUN_CODEGEN_TESTS.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
@@ -1147,6 +1197,62 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j64 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j128 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j128 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage("Build CK and run Tests")
|
||||
{
|
||||
@@ -1190,11 +1296,11 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK for all gfx9 targets")
|
||||
stage("Build CK and run Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { (params.BUILD_GFX942.toBoolean() || params.RUN_FULL_QA.toBoolean()) && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
@@ -1205,6 +1311,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1212,6 +1319,29 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx950")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx950") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx908")
|
||||
{
|
||||
when {
|
||||
@@ -1225,6 +1355,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1236,7 +1367,7 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { params.BUILD_GFX90A.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
@@ -1245,6 +1376,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx90a" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1252,7 +1384,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK instances for different targets")
|
||||
stage("Build CK instances for all supported targets")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
@@ -1263,8 +1395,7 @@ pipeline {
|
||||
execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1151;gfx1201" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
|
||||
-D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1275,15 +1406,16 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { params.BUILD_GFX10.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1030") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx10-3-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx1030" \
|
||||
-DGPU_TARGETS="gfx10-3-generic" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1295,15 +1427,16 @@ pipeline {
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx1101") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx1101" \
|
||||
-DGPU_TARGETS="gfx11-generic" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1319,11 +1452,12 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx1201") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1201" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx1201" \
|
||||
-DGPU_TARGETS="gfx12-generic" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
|
||||
4
client_example/32_gemm_mx/CMakeLists.txt
Normal file
4
client_example/32_gemm_mx/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
330
client_example/32_gemm_mx/gemm_mx_fp8.cpp
Normal file
330
client_example/32_gemm_mx/gemm_mx_fp8.cpp
Normal file
@@ -0,0 +1,330 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_same_v = ck::is_same<X, Y>::value;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AScaleLayout = Row;
|
||||
using BScaleLayout = Col;
|
||||
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
mem_size_ = mem_size;
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
std::size_t mem_size_;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
/* Require by mx type*/
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
K = std::stoi(argv[3]);
|
||||
|
||||
StrideA = std::stoi(argv[4]);
|
||||
StrideB = std::stoi(argv[5]);
|
||||
StrideC = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_matrix_space_size =
|
||||
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
|
||||
using Layout = decltype(layout);
|
||||
|
||||
if constexpr(std::is_same<Layout, Row>::value)
|
||||
{
|
||||
return (nRow - 1) * stride + nCol;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (nCol - 1) * stride + nRow;
|
||||
}
|
||||
};
|
||||
|
||||
/* Scale stride Calculation */
|
||||
auto f_get_default_stride =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
return static_cast<ck::index_t>(col);
|
||||
else
|
||||
return static_cast<ck::index_t>(row);
|
||||
}
|
||||
else
|
||||
return static_cast<ck::index_t>(stride);
|
||||
};
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize;
|
||||
auto Scale_Stride_AM =
|
||||
f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
|
||||
|
||||
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
|
||||
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
|
||||
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
|
||||
SimpleDeviceMem a_scale_device_buf(
|
||||
sizeof(XDataType) *
|
||||
f_matrix_space_size(Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
|
||||
SimpleDeviceMem b_scale_device_buf(
|
||||
sizeof(XDataType) *
|
||||
f_matrix_space_size(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
|
||||
|
||||
using DeviceOp =
|
||||
ck::tensor_operation::device::DeviceGemmMX<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
XPackedDataType,
|
||||
BDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop =
|
||||
std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
|
||||
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * M * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * N * K / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
if(found)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -32,7 +32,7 @@ if (DTYPES)
|
||||
add_definitions(-DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
message(DEBUG "DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
|
||||
@@ -14,8 +14,10 @@ cd client_example/build
|
||||
cmake \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
|
||||
-D GPU_TARGETS="gfx908;gfx90a" \
|
||||
..
|
||||
```
|
||||
You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s).
|
||||
|
||||
### Build client example
|
||||
```bash
|
||||
|
||||
@@ -66,7 +66,8 @@ else()
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
-Wno-reserved-identifier
|
||||
-Werror
|
||||
# Werror set outside by BUILD_DEV
|
||||
# -Werror
|
||||
-Wno-option-ignored
|
||||
-Wsign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
@@ -108,7 +109,7 @@ else()
|
||||
endif()
|
||||
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||
-Wno-missing-field-initializers
|
||||
-Wno-deprecated-declarations
|
||||
-Wno-error=deprecated-declarations
|
||||
)
|
||||
endif()
|
||||
add_definitions(${CMAKE_COMPILER_WARNINGS})
|
||||
|
||||
116
cmake/ShardInstantiation.cmake
Normal file
116
cmake/ShardInstantiation.cmake
Normal file
@@ -0,0 +1,116 @@
|
||||
# Function to generate templated instantiation functions and caller function.
|
||||
|
||||
# In order to reduce build times, we split the instantiation of template functions into multiple files.
|
||||
# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions,
|
||||
# which can be placed the TEMPLATE_FILE (typically a .in file).
|
||||
|
||||
# This CMake function generates the instantiation functions and a caller function that calls all the instantiation
|
||||
# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of
|
||||
# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard,
|
||||
# and generates a caller function that calls all the instantiation functions.
|
||||
|
||||
# The explicit instatiation pattern requires the use of `extern template` to avoid implicit instantiation
|
||||
# of the template functions in the caller function, and that code is automatically generated by this function.
|
||||
|
||||
# In addition to the user-supplied template, this CMake function uses two generic templates:
|
||||
#
|
||||
# 1. `instantiate_shard.in`: This is the template for the instantiation functions.
|
||||
# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions.
|
||||
|
||||
# This function takes the following arguments:
|
||||
#
|
||||
# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCE_NAMES}`).
|
||||
# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions.
|
||||
# - NUM_SHARDS: The number of shards to generate.
|
||||
# - OUTPUT_DIR: The build directory where the generated source files will be placed.
|
||||
# - SRC_LIST: The list of source files to which the generated source files will be added.
|
||||
|
||||
|
||||
function(generate_sharded_instantiations)
|
||||
cmake_parse_arguments(
|
||||
GEN_SHARDED
|
||||
# No boolean arguments
|
||||
""
|
||||
# Single-value arguments
|
||||
"INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST"
|
||||
# No multi-value arguments.
|
||||
""
|
||||
${ARGN}
|
||||
)
|
||||
if (NOT GEN_SHARDED_INSTANCES_NAME)
|
||||
message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if (NOT GEN_SHARDED_TEMPLATE_FILE)
|
||||
message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if (NOT GEN_SHARDED_NUM_SHARDS)
|
||||
message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if(NOT GEN_SHARDED_OUTPUT_DIR)
|
||||
message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
if (NOT GEN_SHARDED_SRC_LIST)
|
||||
message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations")
|
||||
endif()
|
||||
|
||||
file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR})
|
||||
|
||||
|
||||
set(GENERATED_SOURCE_FILES "")
|
||||
set(EXTERN_TEMPLATE_STATEMENTS "")
|
||||
set(CALL_STATEMENTS "")
|
||||
message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")
|
||||
|
||||
set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}")
|
||||
|
||||
# Generate the inc file with the template function defintions.
|
||||
# This include file will hold the template function definitions and a using alias for all the shard
|
||||
# instantiation functions.
|
||||
configure_file(
|
||||
"${GEN_SHARDED_TEMPLATE_FILE}"
|
||||
"${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc"
|
||||
@ONLY
|
||||
)
|
||||
|
||||
# Generate the sharded instantiation functions.
|
||||
# This is where the build parallelization happens.
|
||||
# Each of these source files will contain a single instantiation function for a shard,
|
||||
# which will be called sequentially by the caller function.
|
||||
set(INC_DIR "${GEN_SHARDED_INC_DIR}")
|
||||
math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1")
|
||||
foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID})
|
||||
set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}")
|
||||
set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp")
|
||||
set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in")
|
||||
configure_file(
|
||||
"${SHARD_FUNCTION_TEMPLATE}"
|
||||
"${SHARD_FUNCTION_PATH}"
|
||||
@ONLY
|
||||
)
|
||||
list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}")
|
||||
set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>")
|
||||
list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)")
|
||||
list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)")
|
||||
endforeach()
|
||||
|
||||
# Join the include statements, the extern template declarations, and the call statements each
|
||||
# into a single string for variable substitution in the caller function.
|
||||
string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}")
|
||||
string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}")
|
||||
string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}")
|
||||
|
||||
# Generate the caller function.
|
||||
set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp")
|
||||
set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in")
|
||||
configure_file(
|
||||
"${FUNCTION_TEMPLATE}"
|
||||
"${CALLER_FUNCTION_PATH}"
|
||||
@ONLY
|
||||
)
|
||||
list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}")
|
||||
|
||||
# Add the generated source files to the list of source files.
|
||||
# This allows the generated source files to be included in the build.
|
||||
list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES})
|
||||
set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
15
cmake/call_shard.in
Normal file
15
cmake/call_shard.in
Normal file
@@ -0,0 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "@INSTANCES@.inc"
|
||||
|
||||
namespace ck::tensor_operation::device::instance {
|
||||
|
||||
@EXTERN_TEMPLATE_STATEMENTS@;
|
||||
|
||||
void add_@INSTANCES@(
|
||||
@INSTANCES@& instances) {
|
||||
@CALL_STATEMENTS@;
|
||||
}
|
||||
|
||||
} // namespace ck::tensor_operation::device::instance
|
||||
9
cmake/instantiate_shard.in
Normal file
9
cmake/instantiate_shard.in
Normal file
@@ -0,0 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "@INSTANCES@.inc"
|
||||
|
||||
namespace ck::tensor_operation::device::instance {
|
||||
template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>(
|
||||
@INSTANCES@& instances);
|
||||
} // namespace ck::tensor_operation::device::instance
|
||||
@@ -19,9 +19,7 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
|
||||
include(Embed)
|
||||
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
${CK_ROOT}/include/ck/*.hpp)
|
||||
# printouts fot debug purposes
|
||||
# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
# message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
@@ -48,6 +46,7 @@ rocm_install_targets(
|
||||
INCLUDE include
|
||||
)
|
||||
rocm_export_targets(
|
||||
TARGETS ck_host ck_headers
|
||||
EXPORT ck_host_targets
|
||||
NAMESPACE composable_kernel::
|
||||
)
|
||||
|
||||
@@ -8,5 +8,5 @@ target_link_libraries(ck_rtc PUBLIC -lstdc++fs)
|
||||
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
|
||||
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
|
||||
target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS)
|
||||
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
|
||||
message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core[api_reference]==1.18.4
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
rocm-docs-core[api_reference]==1.20.1
|
||||
sphinxcontrib-bibtex==2.6.4
|
||||
|
||||
@@ -237,7 +237,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core[api-reference]==1.18.4
|
||||
rocm-docs-core[api-reference]==1.20.1
|
||||
# via -r requirements.in
|
||||
rpds-py==0.24.0
|
||||
# via
|
||||
@@ -278,7 +278,7 @@ sphinx-notfound-page==1.1.0
|
||||
# via rocm-docs-core
|
||||
sphinxcontrib-applehelp==2.0.0
|
||||
# via sphinx
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
sphinxcontrib-bibtex==2.6.4
|
||||
# via -r requirements.in
|
||||
sphinxcontrib-devhelp==2.0.0
|
||||
# via sphinx
|
||||
|
||||
19
example/01_gemm/CMakeLists.txt
Executable file → Normal file
19
example/01_gemm/CMakeLists.txt
Executable file → Normal file
@@ -39,6 +39,12 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3)
|
||||
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
|
||||
|
||||
set(GEMM_OPTIONS)
|
||||
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16")
|
||||
example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
|
||||
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -109,3 +115,16 @@ add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16)
|
||||
add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8)
|
||||
|
||||
add_example_executable(example_gemm_wmma_bf16_v3 gemm_wmma_bf16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_v3)
|
||||
add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3)
|
||||
add_example_executable(example_gemm_wmma_fp8_v3 gemm_wmma_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
@@ -57,8 +59,9 @@ struct ProblemSizeStreamK_universal final
|
||||
ck::index_t StrideB = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
ck::StreamKReductionStrategy reduction_strategy = ck::StreamKReductionStrategy::Atomic;
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
@@ -128,11 +131,12 @@ bool parse_cmd_args<ProblemSize>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -172,7 +176,19 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.Streamk_sel = std::stoi(argv[10]);
|
||||
problem_size.Grid_size = std::stoi(argv[11]);
|
||||
|
||||
if(argc >= 12)
|
||||
{
|
||||
problem_size.Grid_size = std::stoi(argv[11]);
|
||||
|
||||
if(argc >= 13)
|
||||
{
|
||||
int reduction_strategy = std::stoi(argv[12]);
|
||||
problem_size.reduction_strategy = reduction_strategy == 0
|
||||
? ck::StreamKReductionStrategy::Atomic
|
||||
: ck::StreamKReductionStrategy::Reduction;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -181,9 +197,12 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
|
||||
<< std::endl
|
||||
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
|
||||
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
|
||||
<< std::endl
|
||||
<< "arg11: Grid_size(-1 for max occupancy)" << std::endl
|
||||
<< "arg12: Reduction strategy (0: Atomic, 1: Reduction)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -227,13 +246,14 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
|
||||
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
|
||||
<< std::endl
|
||||
<< "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
|
||||
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -277,12 +297,13 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: KBatch" << std::endl;
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
|
||||
<< std::endl
|
||||
<< "arg10: KBatch" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
253
example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp
Normal file
253
example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::pk_i4_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = true;
|
||||
static constexpr ck::index_t KPerBlock = 32;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128, KPerBlock,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1,
|
||||
ADataType, ADataType, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// weight permute
|
||||
if constexpr(PermuteB)
|
||||
{
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
47
example/01_gemm/gemm_wmma_bf16_v3.cpp
Normal file
47
example/01_gemm/gemm_wmma_bf16_v3.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
256,
|
||||
128, 128, 32,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
52
example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp
Normal file
52
example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128, 32,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
302
example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp
Normal file
302
example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp
Normal file
@@ -0,0 +1,302 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::pk_i4_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = true;
|
||||
static constexpr ck::index_t KPerBlock = 32;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256,
|
||||
128, 128, KPerBlock,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 1,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1,
|
||||
ADataType, ADataType, PermuteA, PermuteB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// weight permute
|
||||
if constexpr(PermuteB)
|
||||
{
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int i4x2 = b_k_n_permute(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
47
example/01_gemm/gemm_wmma_fp16_v3.cpp
Normal file
47
example/01_gemm/gemm_wmma_fp16_v3.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
128, 64,
|
||||
64, 8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
1, 1, S<1, 32, 1, 4>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
67
example/01_gemm/gemm_wmma_fp8_v3.cpp
Normal file
67
example/01_gemm/gemm_wmma_fp8_v3.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using CDataType = ck::bhalf_t;
|
||||
using ComputeTypeA = ck::f8_t;
|
||||
using ComputeTypeB = ck::f8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
128, 64, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 32, 1, 4>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
|
||||
ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceComputeType = ck::f8_t;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ReferenceComputeType,
|
||||
ReferenceComputeType>;
|
||||
|
||||
#include "run_gemm_example_v2.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
if(!ck::is_gfx12_supported())
|
||||
{
|
||||
std::cout << "This kernel support gfx12 only" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
return !run_gemm_splitk_example(argc, argv);
|
||||
}
|
||||
0
example/01_gemm/gemm_xdl_bf16.cpp
Executable file → Normal file
0
example/01_gemm/gemm_xdl_bf16.cpp
Executable file → Normal file
0
example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp
Executable file → Normal file
0
example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp
Executable file → Normal file
@@ -32,6 +32,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// this instance has been tested working on gfx950
|
||||
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
0
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
Executable file → Normal file
0
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp
Executable file → Normal file
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
#else
|
||||
// clang-format off
|
||||
|
||||
@@ -33,7 +33,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
|
||||
@@ -36,7 +36,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
|
||||
@@ -21,6 +21,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
auto Grid_size = problem_size.Grid_size;
|
||||
auto Streamk_sel = problem_size.Streamk_sel;
|
||||
|
||||
auto reduction_strategy = problem_size.reduction_strategy;
|
||||
if(reduction_strategy == ck::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
std::cout << "Using Atomic reduction strategy" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Using Parallel reduction strategy" << std::endl;
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
@@ -35,7 +45,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
@@ -152,7 +162,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
Grid_size,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
c_element_op,
|
||||
reduction_strategy);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -242,7 +253,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
<< " GB/s, " << gemm.GetTypeString()
|
||||
<< (reduction_strategy == ck::StreamKReductionStrategy::Atomic ? " (Atomic)"
|
||||
: " (Reduction)")
|
||||
<< std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
if(stride == -1 || stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -34,7 +34,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
|
||||
@@ -141,8 +141,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
|
||||
a_tensors_device.reserve(group_count);
|
||||
b_tensors_device.reserve(group_count);
|
||||
d_tensors_device.reserve(group_count);
|
||||
c_tensors_device.reserve(group_count);
|
||||
d_tensors_device.resize(group_count); // reserve and update vector size
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
@@ -71,9 +71,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
256, // BlockSize
|
||||
256, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
64, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
@@ -84,14 +84,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
0, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
0, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
|
||||
// clang-format on
|
||||
|
||||
#else
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp)
|
||||
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
|
||||
set(EXAMPLE_COMPILE_OPTIONS)
|
||||
# Open it when SGBPack branch landed on mainline
|
||||
# list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental")
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950)
|
||||
set(target 0)
|
||||
@@ -19,9 +28,38 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(HAS_MAX_ILP_SCHEDULING_STRATEGY)
|
||||
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
|
||||
endif()
|
||||
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
endif()
|
||||
set(GEMM_OPTIONS)
|
||||
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(GEMM_OPTIONS)
|
||||
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
set(BLOCKSCALE_GEMM_OPTIONS)
|
||||
check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP)
|
||||
check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION)
|
||||
if(HAS_MISCHED_BOTTOMUP)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
|
||||
elseif(HAS_MISCHED_PRERA_DIRECTION)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
|
||||
endif()
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL)
|
||||
if(HAS_MAX_OCCUPANCY_EXPERIMENTAL)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental)
|
||||
endif()
|
||||
# list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1")
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
16, 128,
|
||||
256, 16, 16,
|
||||
128, 128,
|
||||
128, 16, 16,
|
||||
16, 16,
|
||||
1, 2,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 16, 1, 16>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
1, 2, S<1, 32, 1, 8>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
@@ -0,0 +1,372 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using FP8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = FP8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = FP8;
|
||||
using B1DataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using A1Layout = Col;
|
||||
using B0Layout = Col;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
|
||||
// clang-format off
|
||||
<Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
128, 128,
|
||||
128, 16, 16,
|
||||
16, 16,
|
||||
8, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 1, S<1, 32, 1, 8>, S<8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool flush_cache = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 128;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 8)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
flush_cache = std::stoi(argv[7]);
|
||||
|
||||
StrideA = K;
|
||||
StrideB = K;
|
||||
StrideE = N;
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: M, N, K\n");
|
||||
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// Transpose the AScale tensor for better performance
|
||||
ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
Scale_Stride_AK,
|
||||
A1Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<B0DataType> b0_preshuffled(
|
||||
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
|
||||
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
B0Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_m_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
StrideE,
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float ave_time = 0.0f;
|
||||
|
||||
if(flush_cache)
|
||||
{
|
||||
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
|
||||
|
||||
ave_time = invoker.Run(argument,
|
||||
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<AccDataType> c_m_n({M, N});
|
||||
Tensor<float> a_m_k({M, K});
|
||||
Tensor<float> b_k_n({K, N});
|
||||
|
||||
for(int m = 0; m < M; m++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
a_m_k(m, k) = ck::type_convert<float>(a0_m_k(m, k)) *
|
||||
a1_m_k(m / Scale_Block_M, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
|
||||
b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
|
||||
float,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
#if 1
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_m_n_host_result(m, n) = ck::type_convert<EDataType>(c_m_n(m, n));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -139,10 +139,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
|
||||
128, 128, 128,
|
||||
256, 256, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
8, 2,
|
||||
16, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
|
||||
|
||||
@@ -158,21 +158,22 @@ using BElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
|
||||
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
@@ -183,15 +184,15 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
|
||||
// mn_perxdl
|
||||
MNPerXDL, MNPerXDL,
|
||||
// mn_xdlperwave
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
|
||||
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -205,9 +206,9 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t tokens = 64;
|
||||
ck::index_t sorted_tile_num = 256;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
ck::index_t tokens = 16384;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
@@ -263,11 +264,12 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
expert_ids.mData[i] = i / (valid_tile_num / experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
@@ -307,7 +309,7 @@ int main(int argc, char* argv[])
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.1, 0.1});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
|
||||
@@ -0,0 +1,548 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
// using EDataType = F16;
|
||||
using EDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D2Layout>;
|
||||
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D2>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d2);
|
||||
}
|
||||
};
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16 / sizeof(B0DataType);
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
|
||||
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
|
||||
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
//threadnum, mblock, nblock, kblock
|
||||
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
// ak1, bk1
|
||||
AK1, BK1,
|
||||
// mn_perxdl
|
||||
MNPerXDL, MNPerXDL,
|
||||
// mn_xdlperwave
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t topk = 2;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 8192;
|
||||
// ck::index_t sorted_tile_num = 15;
|
||||
// ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_tile_num = 259;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
ck::index_t tokens = 4096;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
ck::index_t K = 7168;
|
||||
ck::index_t experts = 256;
|
||||
ck::index_t topk = 8;
|
||||
ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 261;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
#endif
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
sorted_tile_num = std::stoi(argv[7]);
|
||||
valid_tile_num = std::stoi(argv[8]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<A1DataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<B1DataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N * 2},
|
||||
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_t_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(
|
||||
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s.\n"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> a_t_k({tokens, K});
|
||||
Tensor<float> b_e_n_k({experts, K, N * 2});
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
// handle scale before ref.
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
a_t_k(t, k) = ck::type_convert<float>(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N * 2; ++n)
|
||||
{
|
||||
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
|
||||
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
}
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm1BlockScale<float,
|
||||
float,
|
||||
float,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a_t_k,
|
||||
b_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -123,11 +123,11 @@ using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t MPerBlock = 256;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 4;
|
||||
static constexpr ck::index_t MXDLPerWave = 16;
|
||||
static constexpr ck::index_t NXDLPerWave = 4;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t NPerBlock = 256;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
|
||||
@@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
static constexpr bool PerTokenQuant = true;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
|
||||
// clang-format off
|
||||
@@ -164,12 +165,12 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>;
|
||||
// kernel 2: 128->32x128x128
|
||||
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
|
||||
|
||||
@@ -186,18 +187,18 @@ int main(int argc, char* argv[])
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t sorted_tile_num = 16;
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_tile_num = 133;
|
||||
ck::index_t valid_tile_num = 128;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 128;
|
||||
ck::index_t tokens = 16384;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 3)
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
@@ -238,20 +239,22 @@ int main(int argc, char* argv[])
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
constexpr auto StrideDs = PerTokenQuant ? std::array<ck::index_t, NumDTensor>{1, 1, 0}
|
||||
: std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
|
||||
// max_token_id.mData[0] = valid_size;
|
||||
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
@@ -278,8 +281,10 @@ int main(int argc, char* argv[])
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<D0DataType> d0_t_n(
|
||||
HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}));
|
||||
Tensor<D1DataType> d1_e_n(
|
||||
HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
@@ -0,0 +1,541 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F16;
|
||||
// using EDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
|
||||
using DsLayout = ck::Tuple<D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D2>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
|
||||
// for real kernel use
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for real kernel use
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& e, const float& c, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d2);
|
||||
}
|
||||
};
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16 / sizeof(B0DataType);
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
#if 0
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t MXDLPerWave = 2;
|
||||
static constexpr ck::index_t NXDLPerWave = 2;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
|
||||
|
||||
static constexpr ck::index_t CShuffleNLane = 16;
|
||||
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 2;
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
static constexpr ck::index_t D2Vec = 1;
|
||||
|
||||
// clang-format off
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
AK1, BK1,
|
||||
MNPerXDL, MNPerXDL,
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
|
||||
#else
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// tokens = 1
|
||||
// topk = 1
|
||||
// experts = 8
|
||||
// per expert:
|
||||
|
||||
constexpr ck::index_t valid_tile_num =
|
||||
26; // 13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock
|
||||
constexpr ck::index_t sorted_tile_num = valid_tile_num + 3;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
ck::index_t K = 7160;
|
||||
ck::index_t experts = 256;
|
||||
ck::index_t tokens = 1;
|
||||
ck::index_t topk = 8;
|
||||
#endif
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
|
||||
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
|
||||
// int eids[] = {0, 1, 3, 3, 3};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
// int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
// 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
// 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
|
||||
// 3, 3, 3, 3, 3, 3, 3, 3, 4, 4,
|
||||
// 5, 5, 5, 5, 6, 6, 6, 6, 7, 7,
|
||||
// 7, 7,
|
||||
// 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<A1DataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<B1DataType> b1_e_n_k(HostTensorDescriptor(
|
||||
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N},
|
||||
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{1.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{1.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{1.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{1.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{1.0, 1.0});
|
||||
for(auto i = 0; i < N * K; i++)
|
||||
{
|
||||
b0_e_n_k.mData[i] = ck::type_convert<B0DataType>(static_cast<float>(0.1));
|
||||
b0_e_n_k.mData[i + N * K] = ck::type_convert<B0DataType>(static_cast<float>(0.2));
|
||||
}
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_t_k_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk +
|
||||
sizeof(B0DataType) * K * N * experts +
|
||||
sizeof(EDataType) * tokens * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s.\n"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> a_t_k_k({tokens, topk, K});
|
||||
Tensor<float> b_e_n_k({experts, K, N});
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int tk = 0; tk < topk; ++tk)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
a_t_k_k(t, tk, k) = ck::type_convert<float>(a0_t_k_k(t, tk, k)) *
|
||||
a1_t_k_k(t, tk, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
|
||||
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
|
||||
float,
|
||||
float,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a_t_k_k,
|
||||
b_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
0
example/66_complex_contraction_bilinear/CMakeLists.txt
Executable file → Normal file
0
example/66_complex_contraction_bilinear/CMakeLists.txt
Executable file → Normal file
0
example/66_complex_contraction_bilinear/README.md
Executable file → Normal file
0
example/66_complex_contraction_bilinear/README.md
Executable file → Normal file
0
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp
Executable file → Normal file
0
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp
Executable file → Normal file
0
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp
Executable file → Normal file
0
example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp
Executable file → Normal file
@@ -6,6 +6,33 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
|
||||
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
|
||||
# TODO: Fix RRR
|
||||
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
|
||||
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
|
||||
|
||||
add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle)
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
|
||||
|
||||
set(FP4_MXGEMM_OPTIONS)
|
||||
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
|
||||
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP8_MXGEMM_OPTIONS)
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
|
||||
@@ -21,11 +21,11 @@ using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 128;
|
||||
constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
@@ -45,32 +45,32 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
128, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
16, // NPerBlock
|
||||
32, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
1, // NXdlPerWave
|
||||
S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
2, // NXdlPerWave
|
||||
S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<16, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
2, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -23,8 +23,9 @@
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using MFMA = ck::tensor_layout::gemm::MFMA;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
@@ -36,6 +37,8 @@ struct ExecutionConfig final
|
||||
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
|
||||
bool time_kernel = false; // (0=no, 1=yes)
|
||||
int verbosity = 0; // (0=no info, 1=verbose info)
|
||||
int warm_up = 10;
|
||||
int repeat = 10;
|
||||
};
|
||||
|
||||
struct ProblemSizeSplitK final
|
||||
@@ -86,6 +89,8 @@ bool parse_cmd_args(int argc,
|
||||
if(argc >= 12)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[11]);
|
||||
config.warm_up = std::stoi(argv[12]);
|
||||
config.repeat = std::stoi(argv[13]);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -103,10 +108,90 @@ bool parse_cmd_args(int argc,
|
||||
return true;
|
||||
}
|
||||
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
|
||||
// 2-k)));
|
||||
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceOpInstance,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename XDataType,
|
||||
typename XPackedDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -119,6 +204,8 @@ template <typename DeviceOpInstance,
|
||||
ck::index_t ScaleBlockSize>
|
||||
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
constexpr bool BPreShuffle = ck::is_same_v<BLayout, MFMA>;
|
||||
using BRefLayout = ck::conditional_t<BPreShuffle, Col, BLayout>;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
@@ -131,28 +218,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
auto f_host_tensor_descriptor =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<ck::index_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<ck::index_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<ck::index_t>(stride);
|
||||
@@ -172,16 +250,30 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
using AScaleLayout = Row;
|
||||
using BScaleLayout = Col;
|
||||
|
||||
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Padded_M = ck::math::integer_least_multiple(M, ScaleBlockSize);
|
||||
auto Scale_Stride_AM =
|
||||
f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
auto b_k_n =
|
||||
std::make_shared<Tensor<BDataType>>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{}));
|
||||
auto b_input = b_k_n;
|
||||
if constexpr(BPreShuffle)
|
||||
b_input = std::make_shared<Tensor<BDataType>>(
|
||||
f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size
|
||||
|
||||
// scales for A and B
|
||||
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
|
||||
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
|
||||
Tensor<XDataType> b_k_n_scale(f_host_tensor_descriptor(
|
||||
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
|
||||
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
|
||||
Tensor<XDataType> b_k_n_scale(
|
||||
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
|
||||
|
||||
// shuffled scales for A and B
|
||||
Tensor<XDataType> a_shuffled_scale(f_host_tensor_descriptor(
|
||||
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
|
||||
Tensor<XDataType> b_shuffled_scale(
|
||||
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
|
||||
@@ -192,18 +284,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
{
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n->mDesc << std::endl;
|
||||
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
|
||||
std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl;
|
||||
}
|
||||
|
||||
auto a_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<ADataType>(ck::float2_t(x));
|
||||
else
|
||||
return ck::type_convert<ADataType>(x);
|
||||
};
|
||||
auto b_data_element = [](float x) {
|
||||
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
|
||||
return ck::type_convert<BDataType>(ck::float2_t(x));
|
||||
else
|
||||
return ck::type_convert<BDataType>(x);
|
||||
};
|
||||
|
||||
using int_distr = std::uniform_int_distribution<int>;
|
||||
using float_distr = std::uniform_real_distribution<float>;
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: // Initializations for development and debugging
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
|
||||
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Init A = {1}" << std::endl;
|
||||
@@ -215,31 +322,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
break;
|
||||
|
||||
case 1:
|
||||
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
|
||||
|
||||
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
|
||||
{
|
||||
a_m_k_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorValue(
|
||||
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
}
|
||||
else
|
||||
{
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
|
||||
}
|
||||
|
||||
a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
|
||||
b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
|
||||
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
|
||||
a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
|
||||
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
|
||||
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});
|
||||
b_k_n_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -249,20 +344,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
|
||||
a_shuffled_scale.mData.data(),
|
||||
Scale_Padded_M,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<BRefLayout, Col>>(
|
||||
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
|
||||
if constexpr(BPreShuffle)
|
||||
{
|
||||
int NPerXdl = 16; // Fixed 16
|
||||
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl);
|
||||
}
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Device memory allocation..." << std::endl;
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
|
||||
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize());
|
||||
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize());
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Upload data to device..." << std::endl;
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
|
||||
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
|
||||
b_device_buf.ToDevice(b_input->mData.data());
|
||||
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Done." << std::endl;
|
||||
|
||||
@@ -275,9 +383,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
@@ -299,13 +407,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
"not consistent with the supported device_gemm arguments.");
|
||||
}
|
||||
|
||||
std::size_t total_size =
|
||||
a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() +
|
||||
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() +
|
||||
a_shuffled_scale.GetElementSpaceSizeInBytes() +
|
||||
b_shuffled_scale.GetElementSpaceSizeInBytes();
|
||||
const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size);
|
||||
const int rotating_count = std::max(1, std::min(config.repeat, static_cast<int>(total_cnt)));
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
|
||||
float ave_time = invoker.Run(argument,
|
||||
StreamConfig{nullptr,
|
||||
config.time_kernel,
|
||||
config.verbosity,
|
||||
config.warm_up,
|
||||
config.repeat,
|
||||
rotating_count > 1,
|
||||
rotating_count});
|
||||
|
||||
bool res_verified = true;
|
||||
if(config.do_verification > 0)
|
||||
@@ -332,7 +453,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
|
||||
a_m_k_scale,
|
||||
b_k_n,
|
||||
*b_k_n,
|
||||
b_k_n_scale,
|
||||
c_m_n_host_result,
|
||||
PassThrough{},
|
||||
@@ -347,20 +468,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
std::cout << "Comparing results..." << std::endl;
|
||||
}
|
||||
|
||||
if(config.init_method == 0)
|
||||
{
|
||||
auto expected = static_cast<float>(K);
|
||||
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
|
||||
|
||||
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
|
||||
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
|
||||
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!");
|
||||
res_verified =
|
||||
res_verified &&
|
||||
ck::utils::check_err(
|
||||
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1);
|
||||
|
||||
if(config.verbosity > 0 && res_verified)
|
||||
std::cout << "Verification Successful!" << std::endl;
|
||||
@@ -377,13 +488,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
|
||||
// partial sums(K/ScaleBlockSize)]
|
||||
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
|
||||
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> + sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
float gb_per_sec = static_cast<float>(num_btype) / 1e6f / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
@@ -396,6 +508,7 @@ template <typename DeviceOpInstance,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename XDataType,
|
||||
typename XPackedDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -416,6 +529,7 @@ bool run_mx_gemm_example(int argc, char* argv[])
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
103
example/67_gemm_microscaling/gemm_mx_fp4.cpp
Normal file
103
example/67_gemm_microscaling/gemm_mx_fp4.cpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f4x2_pk_t;
|
||||
using BDataType = ck::f4x2_pk_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
// AB DataType: f4x2_pk_t
|
||||
// Mathmatically, all numbers are represented as f4x2.
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XPackedDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XPackedDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
256, // MPerBlock
|
||||
256, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
103
example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp
Normal file
103
example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp
Normal file
@@ -0,0 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gemm_mx_common.hpp"
|
||||
|
||||
using ADataType = ck::f4x2_pk_t;
|
||||
using BDataType = ck::f4x2_pk_t;
|
||||
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = MFMA;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
// AB DataType: f4x2_pk_t
|
||||
// Mathmatically, all numbers are represented as f4x2.
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XPackedDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XPackedDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
512, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
4, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlockW
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return run_mx_gemm_example<DeviceOpInstance,
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XPackedDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ScaleBlockSize>(argc, argv)
|
||||
? 0
|
||||
: -1;
|
||||
}
|
||||
@@ -25,7 +25,7 @@ constexpr ck::index_t KPerBlock = 256;
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
@@ -49,26 +49,26 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
true, // BBlockLdsExtraN
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -24,7 +24,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
|
||||
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
@@ -43,30 +43,30 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
|
||||
GemmSpec, // GemmSpec
|
||||
ScaleBlockSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
256, // MPerBlock
|
||||
256, // NPerBlock
|
||||
128, // KPerBlock
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
256, // KPerBlock
|
||||
16, // AK1
|
||||
8, // BK1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
8, // MXdlPerWave
|
||||
8, // NXdlPerWave
|
||||
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
4, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<32, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
@@ -82,6 +82,7 @@ int main(int argc, char* argv[])
|
||||
ADataType,
|
||||
BDataType,
|
||||
XDataType,
|
||||
XDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
|
||||
545
example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp
Normal file
545
example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp
Normal file
@@ -0,0 +1,545 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t BlockSize = 256;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, BlockSize,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
|
||||
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_k_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_k_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
|
||||
e_t_k_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{0.5f});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.5f});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k(token_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A/B scale shuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop =
|
||||
// FMA * tokens * N * (Gate+Up) * topk * K +
|
||||
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
|
||||
std::size_t(2) * tokens * N * 2 * topk * K +
|
||||
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
|
||||
sizeof(EDataType) * tokens * topk * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
a1_t_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_k_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
526
example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp
Normal file
526
example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp
Normal file
@@ -0,0 +1,526 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
|
||||
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
int eids[sorted_tile_num]{};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
if(i < valid_tile_num)
|
||||
{
|
||||
eids[i] = (i * experts) / valid_tile_num;
|
||||
}
|
||||
else
|
||||
{
|
||||
eids[i] = 3;
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<XDataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// FMA * tokens * N * topk * K +
|
||||
// FMA * tokens * N * topk * (K/BlockScale)
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
|
||||
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<CShuffleDataType> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k_k,
|
||||
a1_t_k_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n, // topk weights
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -20,7 +20,7 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
|
||||
endfunction(add_example_dependencies EXAMPLE_NAME)
|
||||
|
||||
function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
message(DEBUG "adding example ${EXAMPLE_NAME}")
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
@@ -47,7 +47,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message("removing example source file ${source} ")
|
||||
message(DEBUG "removing example source file ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -58,70 +58,72 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl example ${source} ")
|
||||
message(DEBUG "removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any DPP examples if DPP_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
|
||||
message("removing dpp example ${source} ")
|
||||
message(DEBUG "removing dpp example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl example ${source} ")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
message(DEBUG "removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any microscaling examples if gfx950 target is not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
|
||||
message("removing microscaling example ${source} ")
|
||||
message(DEBUG "removing microscaling example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
|
||||
message("removing fp8 example ${source} ")
|
||||
message(DEBUG "removing fp8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
|
||||
message("removing bf8 example ${source} ")
|
||||
message(DEBUG "removing bf8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_fp8_bpreshuffle")
|
||||
message("Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
|
||||
if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)")
|
||||
message(DEBUG "Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
|
||||
message("trimming targets for ${FILE_NAME}")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
message(DEBUG "trimming targets for ${FILE_NAME}")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
@@ -133,13 +135,11 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 0)
|
||||
endif()
|
||||
#message("add_example returns ${result}")
|
||||
message(DEBUG "add_example returns ${result}")
|
||||
if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES)
|
||||
#message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}")
|
||||
set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST")
|
||||
add_dependencies(smoke ${EXAMPLE_NAME})
|
||||
elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES)
|
||||
#message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}")
|
||||
set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST")
|
||||
add_dependencies(regression ${EXAMPLE_NAME})
|
||||
endif()
|
||||
@@ -153,7 +153,7 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
|
||||
endfunction(add_example_dependencies EXAMPLE_NAME)
|
||||
|
||||
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
message(DEBUG "adding example ${EXAMPLE_NAME}")
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
@@ -180,7 +180,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message("removing example ${source} ")
|
||||
message(DEBUG "removing example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -191,28 +191,28 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl example ${source} ")
|
||||
message(DEBUG "removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
message("removing xdl example ${source} ")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
message(DEBUG "removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
endif()
|
||||
@@ -224,12 +224,18 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 0)
|
||||
endif()
|
||||
|
||||
#message("add_example returns ${result}")
|
||||
|
||||
message(DEBUG "add_example returns ${result}")
|
||||
set(result ${result} PARENT_SCOPE)
|
||||
|
||||
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
|
||||
|
||||
function(example_compile_options EXAMPLE_NAME)
|
||||
if(TARGET ${EXAMPLE_NAME})
|
||||
target_compile_options(${EXAMPLE_NAME} ${ARGN})
|
||||
endif()
|
||||
endfunction(example_compile_options)
|
||||
|
||||
# add all example subdir
|
||||
file(GLOB dir_list LIST_DIRECTORIES true *)
|
||||
FOREACH(subdir ${dir_list})
|
||||
|
||||
@@ -25,7 +25,7 @@ execute_process(
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
|
||||
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
@@ -34,7 +34,7 @@ execute_process(
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
|
||||
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
|
||||
@@ -57,7 +57,7 @@ add_custom_command(
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_FWD}")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
|
||||
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
@@ -65,7 +65,7 @@ target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_FMHA_BWD}")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
|
||||
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
|
||||
@@ -71,6 +71,7 @@ args:
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
@@ -117,8 +117,50 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cu) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cu = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
@@ -134,36 +176,50 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -220,17 +276,18 @@ class FmhaFwdApiTrait:
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -297,8 +354,8 @@ class FmhaFwdApiPool:
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
@@ -313,25 +370,27 @@ class FmhaFwdApiPool:
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
@@ -423,33 +482,21 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -458,54 +505,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
tiles = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
|
||||
@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
@@ -275,31 +282,32 @@ class FmhaFwdApiPool:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -381,6 +389,7 @@ class FmhaFwdKernel:
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -419,25 +428,28 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -445,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -453,36 +465,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -498,17 +510,15 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for pipeline in get_pipelines(dtype, hdim, hdim_v):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
if (hdim, hdim_v) == (192, 128) or hdim == 160:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
@@ -532,6 +542,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -540,6 +551,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
@@ -565,6 +577,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
# 160: 160,
|
||||
256: 256
|
||||
}
|
||||
|
||||
@@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
|
||||
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_bwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
p_hp_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
|
||||
p_dropped_hp_host_ref = p_hp_host_ref;
|
||||
randval_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
|
||||
});
|
||||
ck_tile::reference_batched_dropout(
|
||||
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
|
||||
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
else
|
||||
{
|
||||
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
}
|
||||
|
||||
// O = P * V
|
||||
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
auto idx_gmo = idx_gmn;
|
||||
idx_gmo[2] = o;
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
|
||||
}
|
||||
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
|
||||
});
|
||||
ck_tile::make_ParallelTensorFunctor(
|
||||
[&](auto i0, auto i1, auto i2) {
|
||||
AccDataType do_dot_o = 0;
|
||||
for(int o = 0; o < hdim_v; o++)
|
||||
{
|
||||
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
|
||||
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
|
||||
}
|
||||
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
|
||||
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
|
||||
},
|
||||
ds_hp_host_ref.mDesc.get_lengths()[0],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[1],
|
||||
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
|
||||
|
||||
if(use_dbias)
|
||||
{
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
|
||||
});
|
||||
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
|
||||
}
|
||||
|
||||
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
|
||||
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
|
||||
});
|
||||
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
|
||||
|
||||
// dV = P_drop^T@dO^T
|
||||
// dV = P^T@dO^T w/o dropout
|
||||
|
||||
47
example/ck_tile/01_fmha/fmha_fwd.cpp
Normal file → Executable file
47
example/ck_tile/01_fmha/fmha_fwd.cpp
Normal file → Executable file
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
@@ -178,50 +178,30 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
}
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
||||
max_splits = std::min({max_splits, num_SMs});
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
|
||||
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
|
||||
// (i.e. it's 11 splits anyway).
|
||||
// So we check if the number of blocks per split is the same as the previous num_splits.
|
||||
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
|
||||
return num_splits == 1 ||
|
||||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
|
||||
};
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
efficiency.push_back(0.f);
|
||||
}
|
||||
else
|
||||
{
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
||||
{
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
@@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks,
|
||||
int override_num_splits_if_necessary(
|
||||
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||||
{
|
||||
(void)hdim_v;
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
@@ -250,15 +231,13 @@ int override_num_splits_if_necessary(
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kN1 = hdim_v;
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
@@ -542,8 +521,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
max_seqlen_k = real_seqlen_k;
|
||||
}
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
|
||||
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
|
||||
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
|
||||
sizeof(KDataType) * real_seqlen_k * hdim_q +
|
||||
|
||||
@@ -169,6 +169,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
@@ -713,102 +715,102 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
return FmhaKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.num_total_pages,
|
||||
args.kv_indptr,
|
||||
args.kv_page_indices,
|
||||
#if 0 // we assume page_block_size=1 for now
|
||||
args.kv_last_page_lens,
|
||||
args.page_block_size,
|
||||
#endif
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.scale_s,
|
||||
args.scale_p,
|
||||
args.scale_o,
|
||||
args.logits_soft_cap,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_lse,
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
21
example/ck_tile/01_fmha/mask.hpp
Normal file → Executable file
21
example/ck_tile/01_fmha/mask.hpp
Normal file → Executable file
@@ -21,6 +21,8 @@ enum class mask_enum
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
@@ -42,6 +44,8 @@ struct mask_info
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
@@ -148,7 +152,22 @@ struct mask_info
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
ck_tile::index_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return seqlen_q * seqlen_k;
|
||||
ck_tile::index_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
|
||||
@@ -25,7 +25,7 @@ add_custom_command(
|
||||
|
||||
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
|
||||
|
||||
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
|
||||
message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}")
|
||||
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
@@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ args:
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
|
||||
@@ -12,15 +12,23 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
if constexpr(Persistent)
|
||||
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -50,8 +58,10 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
@@ -60,9 +70,12 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
@@ -128,12 +141,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -144,24 +157,24 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
@@ -199,19 +212,39 @@ int run_gemm_example(int argc, char* argv[])
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "i8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::int8_t, ck_tile::int8_t, int32_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
@@ -14,99 +13,28 @@
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_ASYNC 4
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_ASYNC
|
||||
#endif
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_ASYNC)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompAsync
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompAsync
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#else
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
|
||||
struct GemmConfig
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_ASYNC)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
@@ -120,6 +48,169 @@ struct GemmConfig
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
@@ -171,6 +262,15 @@ struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using BDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -186,6 +286,12 @@ struct DataTypeTraits<double>
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
@@ -216,6 +322,51 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -234,11 +385,23 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent = false,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -30,7 +30,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename Tensor,
|
||||
template <typename GemmConfig,
|
||||
typename Tensor,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -63,11 +64,12 @@ void permute_tensor_b(Tensor& tensor)
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GEMM_PIPELINE_SCHEDULER,
|
||||
GemmConfig::Scheduler,
|
||||
true,
|
||||
ck_tile::TailNumber::Full>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
@@ -144,13 +146,31 @@ void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -162,23 +182,55 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
int n_repeat,
|
||||
bool persistent)
|
||||
{
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time;
|
||||
if(persistent)
|
||||
{
|
||||
ave_time = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
true,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -193,13 +245,14 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
@@ -229,6 +282,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -243,8 +297,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
@@ -278,7 +332,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<decltype(b_k_n_dev),
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -304,19 +359,28 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
@@ -351,29 +415,19 @@ int run_gemm_example_with_layouts(int argc,
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
// memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
|
||||
ck_tile::hip_check_error(
|
||||
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_A,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a_m_k.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n.get_element_space_size_in_bytes(),
|
||||
hipMemcpyHostToDevice));
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
@@ -383,16 +437,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
|
||||
d_C,
|
||||
c_m_n_dev_result.get_element_space_size_in_bytes(),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
|
||||
@@ -11,28 +11,22 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn)
|
||||
{
|
||||
if constexpr(Pipeline::PrefetchStages > static_cast<int>(TN))
|
||||
{
|
||||
if(tn == TN)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, TN>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -41,30 +35,36 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
@@ -74,64 +74,118 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
@@ -150,103 +204,14 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
|
||||
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4 || \
|
||||
CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_ASYNC)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
@@ -256,14 +221,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
@@ -272,26 +237,26 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
else
|
||||
{
|
||||
// if(a_layout == "R" && b_layout == "R")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Row{}, Row{}, Row{});
|
||||
// }
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "C" && b_layout == "R")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Col{}, Row{}, Row{});
|
||||
// }
|
||||
// else if(a_layout == "C" && b_layout == "C")
|
||||
// {
|
||||
// return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
// argc, argv, Col{}, Col{}, Row{});
|
||||
// }
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
@@ -299,7 +264,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
@@ -311,31 +276,50 @@ int run_gemm_example(int argc, char* argv[])
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
@@ -346,7 +330,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
run_gemm_example(argc, argv);
|
||||
return !run_gemm_example<GemmConfigComputeV3>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(EXAMPLE_REDUCE "tile_example_reduce")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding example ${EXAMPLE_REDUCE}")
|
||||
message(DEBUG "adding example ${EXAMPLE_REDUCE}")
|
||||
|
||||
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
|
||||
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
@@ -35,7 +35,7 @@ struct Reduce2dShape
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
|
||||
@@ -25,7 +25,7 @@ add_custom_command(
|
||||
|
||||
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
|
||||
|
||||
message("adding ${TILE_RMSNORM2D_FWD}")
|
||||
message(DEBUG "adding ${TILE_RMSNORM2D_FWD}")
|
||||
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
|
||||
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
@@ -74,22 +74,22 @@ struct rmsnorm2d_fwd_traits_
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -97,13 +97,13 @@ struct rmsnorm2d_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -712,4 +712,4 @@ if __name__ == "__main__":
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
gen_blobs(args)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
|
||||
message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp)
|
||||
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
@@ -67,13 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
|
||||
@@ -184,6 +185,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
|
||||
|
||||
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
@@ -191,8 +193,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
@@ -80,22 +80,23 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
|
||||
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr auto WarpSize = ck_tile::get_warp_size();
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -103,13 +104,13 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::HostTensor<ck_tile::null_type> unquant_y_host_ref({m, n});
|
||||
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
@@ -194,7 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon);
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
function (add_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
message(DEBUG "adding ${TARGET_NAME}")
|
||||
# not using add_example_executable() to add target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
|
||||
@@ -49,22 +49,22 @@ struct smoothquant_traits_
|
||||
{
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -72,13 +72,13 @@ struct smoothquant_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -14,14 +14,24 @@ This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-pr_i index data type. (currently only fp32 supported now) (default:int32)
|
||||
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
|
||||
-t number of input tokens (default:32)
|
||||
-e number of experts (default:8)
|
||||
-k topk (default:2)
|
||||
-st_i row stride of input, -1 means same as experts (default:-1)
|
||||
-seed seed to be used, -1 means random every time (default:-1)
|
||||
-kname when set to 1 it will print kernel name (default:0)
|
||||
-v turn CPU validation on (1) or off (0). (default:1)
|
||||
-pr_i index data type. Only int32 is currently supported. (default:int32)
|
||||
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
|
||||
-t number of input tokens. (default:128)
|
||||
If "local_t" presents, this value indicates global concurrency of all ranks.
|
||||
-local_t Number of local input tokens for curent rank. (default:-1)
|
||||
This value must be within range "[0, t)", or "-1"(no such feature)
|
||||
This feature is to simulate EP case where where each rank has different tokens.
|
||||
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
|
||||
-e number of num_experts (default:8)
|
||||
-k topk (default:4)
|
||||
-unit unit_size (default:32)
|
||||
-moe_buf_size moe_buf_size (default:0)
|
||||
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
|
||||
please make sure eid is in ascending order!
|
||||
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
|
||||
-kname prints the kernel name when set to 1 (default:0)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
|
||||
```
|
||||
|
||||
@@ -18,10 +18,20 @@
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "128", "number of input tokens")
|
||||
arg_parser.insert("v", "1", "turn CPU validation on (1) or off (0).")
|
||||
.insert("pr_i", "int32", "index data type. Only int32 is currently supported.")
|
||||
.insert("pr_w", "fp32", "output weight data type. Only fp32 is currently supported.")
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
@@ -30,8 +40,11 @@ auto create_args(int argc, char* argv[])
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
"please make sure eid is in ascending order!")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("seed",
|
||||
"-1",
|
||||
"seed to be used. When set to -1, a random seed will be generated each time "
|
||||
"invoking this example")
|
||||
.insert("kname", "0", "prints the kernel name when set to 1")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -70,6 +83,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int local_tokens = args.get_int("local_t");
|
||||
int num_experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
@@ -95,6 +109,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
return false;
|
||||
}
|
||||
|
||||
// if local_tokens == tokens, not local_token, but better avoid this since no meaning for such
|
||||
// case
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool local_expert_masking = args.get_str("local_eid") != "-1";
|
||||
auto local_expert_masking_host = [&]() {
|
||||
if(local_expert_masking)
|
||||
@@ -143,6 +167,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::DeviceMem local_expert_masking_dev(
|
||||
local_expert_masking_host.get_element_space_size_in_bytes());
|
||||
|
||||
// used for simulating dynamic_tokens for EP case
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
@@ -164,6 +195,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
sorted_ids_dev.GetDeviceBuffer(),
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
@@ -236,13 +268,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, mp:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
workspace_size != 0 ? 1 : 0);
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
}
|
||||
printf(", num_experts:%d, topk:%d, mp:%d, ", num_experts, topk, workspace_size != 0 ? 1 : 0);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
@@ -285,6 +316,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ref_total_tokens_post_pad,
|
||||
num_experts,
|
||||
unit_size,
|
||||
is_local_token ? local_tokens
|
||||
: tokens,
|
||||
local_expert_masking);
|
||||
printf("total_tokens_post_pad:%d(%d), ",
|
||||
ref_total_tokens_post_pad,
|
||||
@@ -334,16 +367,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
try
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec == "fp32" && index_prec == "int32")
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -33,15 +33,18 @@
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -51,32 +54,43 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
@@ -171,6 +185,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
@@ -179,15 +194,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -195,15 +212,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -211,15 +230,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -227,15 +248,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -244,15 +267,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -261,28 +286,53 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
|
||||
@@ -31,4 +31,14 @@ $EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -moe_buf_size=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
|
||||
message("adding ${TARGET_NAME}")
|
||||
message(DEBUG "adding ${TARGET_NAME}")
|
||||
# not using add_example_executable() to add target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
|
||||
|
||||
@@ -38,22 +38,22 @@ struct moe_smoothquant_traits_
|
||||
using InputType = ck_tile::remove_cvref_t<InputType_>;
|
||||
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (warpSize / ThreadPerBlock_N_);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / warpSize);
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -61,13 +61,13 @@ struct moe_smoothquant_traits_
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(warpSize % ThreadPerBlock_N_ == 0);
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % warpSize == 0);
|
||||
return ThreadPerBlock_N_ / warpSize;
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
message(DEBUG "adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
|
||||
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
@@ -16,6 +16,7 @@ struct fused_moe_args
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
|
||||
const void* local_tokens; // [1] if not nullptr, tokens read from here
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* ws_ptr; // size is moe_sorting_get_workspace_size()
|
||||
// if return zero, then could be nullptr
|
||||
|
||||
@@ -28,6 +28,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
|
||||
a.local_tokens,
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
|
||||
@@ -33,15 +33,18 @@
|
||||
|
||||
#else
|
||||
|
||||
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_( \
|
||||
sub_token_tile_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
|
||||
constexpr bool sub_token_onshot = sub_token_onshot_; \
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking>; \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -51,32 +54,43 @@
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
|
||||
#define MOE_SORTING_DISPATCH_SUB_TOKEN_( \
|
||||
row_, sub_token_onshot_, local_expert_masking_, local_token_) \
|
||||
if(row_ % 8 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 4 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else if(row_ % 2 == 0) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_, local_token_); \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
|
||||
#define MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, true) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_, false) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
|
||||
if(is_sub_token_onshot) \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, true, local_expert_masking_) \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
MOE_SORTING_DISPATCH_DYNAMIC_TOKEN_(row_, false, local_expert_masking_) \
|
||||
}
|
||||
|
||||
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
|
||||
@@ -175,6 +189,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
auto row_ = sub_token_ / 8;
|
||||
bool is_sub_token_onshot = a.tokens <= sub_token_;
|
||||
bool is_local_expert_masking = t.local_expert_masking;
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
|
||||
MOE_SORTING_DISPATCH_EMASK_(row_);
|
||||
// MOE_SORTING_DISPATCH_ETILE(0, 0);
|
||||
@@ -183,15 +198,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return -1;
|
||||
}
|
||||
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -199,15 +216,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -215,15 +234,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -231,15 +252,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -248,15 +271,17 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
}()
|
||||
#endif
|
||||
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \
|
||||
#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
[&]() { \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking>; \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -265,30 +290,55 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \
|
||||
return ave_time; \
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
if(t.local_expert_masking) \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
if(is_local_token) \
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
return ave_time; \
|
||||
} \
|
||||
}
|
||||
|
||||
float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
fused_moesorting_args a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
@@ -360,3 +410,8 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
}
|
||||
|
||||
@@ -87,7 +87,18 @@ void topid_unique_gen(
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "128", "num input tokens")
|
||||
arg_parser
|
||||
.insert("t",
|
||||
"128",
|
||||
"number of input tokens.\n"
|
||||
"If \"local_t\" presents, this value indicates global concurrency of all ranks.")
|
||||
.insert(
|
||||
"local_t",
|
||||
"-1",
|
||||
"Number of local input tokens for curent rank.\n"
|
||||
"This value must be within range \"[0, t)\", or \"-1\"(no such feature)\n"
|
||||
"This feature is to simulate EP case where where each rank has different tokens.\n"
|
||||
"Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
@@ -131,6 +142,7 @@ template <typename I, typename W, typename O, typename ST, typename SW, typename
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t local_tokens = arg_parser.get_int("local_t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
@@ -169,6 +181,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
|
||||
|
||||
bool is_local_token = local_tokens >= 0 && local_tokens < tokens;
|
||||
|
||||
if(local_tokens > tokens)
|
||||
{
|
||||
printf("local_tokens:%d larger than tokens:%d, invalid\n", local_tokens, tokens);
|
||||
return false;
|
||||
}
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
@@ -198,11 +218,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens;
|
||||
|
||||
if(is_local_token)
|
||||
{
|
||||
std::cout << "(" << local_tokens << ")";
|
||||
}
|
||||
|
||||
std::cout
|
||||
<< "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", act:"
|
||||
<< ", e:" << experts << ", k:" << topk << stride_str << ", hidden:" << hidden_size
|
||||
<< ", interm:" << intermediate_size << ", tp:" << tp << ", act:"
|
||||
<< activation
|
||||
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
|
||||
@@ -377,6 +403,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
@@ -400,6 +431,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
local_expert_masking ? local_expert_mask_buf.GetDeviceBuffer()
|
||||
: nullptr,
|
||||
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
@@ -463,6 +495,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
if(activation == 0)
|
||||
{
|
||||
@@ -495,6 +528,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
local_expert_masking);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
@@ -506,6 +540,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
ck_tile::DeviceMem local_tokens_dev(sizeof(ck_tile::index_t));
|
||||
if(is_local_token)
|
||||
{
|
||||
local_tokens_dev.ToDevice(&local_tokens);
|
||||
}
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
@@ -542,7 +581,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
intermediate_size / tp,
|
||||
tokens,
|
||||
is_local_token ? local_tokens : tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user