Merge remote-tracking branch 'origin/develop' into vpietila/split-k-param-auto-deduce

This commit is contained in:
Ville Pietilä
2025-06-12 09:45:16 +00:00
523 changed files with 57888 additions and 6825 deletions

0
.pre-commit-config.yaml Executable file → Normal file
View File

View File

@@ -13,10 +13,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for Stream-K version of mixed fp8/bf16 GEMM
* Added GEMM pipeline for microscaling (MX) data types
* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for fMHA forward kernels.
* Added benchmarking support for tile engine GEMM.
* Added rotating buffer feature for CK_Tile GEMM.
### Optimized

View File

@@ -26,17 +26,21 @@ set(version 1.1.0)
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
include(CTest)
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
# CK Codegen requires dataclass which is added in Python 3.7
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
if(NOT CK_USE_ALTERNATIVE_PYTHON)
find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED)
else()
message("Using alternative python version")
message(STATUS "Using alternative python version")
set(EXTRA_PYTHON_PATH)
# this is overly restrictive, we may need to be more flexible on the following
string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
message("alternative python path is: ${EXTRA_PYTHON_PATH}")
message(STATUS "alternative python path is: ${EXTRA_PYTHON_PATH}")
find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
@@ -76,7 +80,7 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_BF16)
set(CK_ENABLE_BF16 "ON")
endif()
message("DTYPES macro set to ${DTYPES}")
message(STATUS "DTYPES macro set to ${DTYPES}")
else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
set(CK_ENABLE_INT8 "ON")
@@ -142,8 +146,8 @@ rocm_setup_version(VERSION ${version})
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}")
message("GPU_TARGETS= ${GPU_TARGETS}")
message("GPU_ARCHS= ${GPU_ARCHS}")
message(STATUS "GPU_TARGETS= ${GPU_TARGETS}")
message(STATUS "GPU_ARCHS= ${GPU_ARCHS}")
if(GPU_ARCHS)
#disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package
unset(GPU_TARGETS CACHE)
@@ -158,9 +162,9 @@ find_package(hip REQUIRED)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
message("hip_version_flat=${hip_VERSION_FLAT}")
message(STATUS "hip_version_flat=${hip_VERSION_FLAT}")
message("checking which targets are supported")
message(STATUS "checking which targets are supported")
#In order to build just the CK library (without tests and examples) for all supported GPU targets
#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
@@ -172,8 +176,10 @@ if(NOT ENABLE_ASAN_PACKAGING)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600300000 AND ${hip_VERSION_FLAT} LESS 600400000)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000)
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000 AND ${hip_VERSION_FLAT} LESS 600443483)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950")
elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx950;gfx10-3-generic;gfx11-generic;gfx12-generic")
endif()
else()
#build CK only for xnack-supported targets when using ASAN
@@ -197,25 +203,25 @@ endif()
rocm_check_target_ids(SUPPORTED_GPU_TARGETS
TARGETS ${CK_GPU_TARGETS})
message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
message("Enabling XDL instances")
message(STATUS "Enabling XDL instances")
add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
message("Enabling XDL FP8 gemms on native architectures")
message(STATUS "Enabling XDL FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances")
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA FP8 gemms on native architectures")
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
add_definitions(-DCK_USE_WMMA_FP8)
set(CK_USE_WMMA_FP8 "ON")
endif()
@@ -244,32 +250,32 @@ configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/con
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302)
check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK)
if(HAS_NO_OFFLOAD_UNIFORM_BLOCK)
message("Adding the fno-offload-uniform-block compiler flag")
message(STATUS "Adding the fno-offload-uniform-block compiler flag")
add_compile_options(-fno-offload-uniform-block)
endif()
endif()
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000)
check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION)
if(HAS_LSR_DROP_SOLUTION)
message("Adding the lsr-drop-solution=1 compiler flag")
message(STATUS "Adding the lsr-drop-solution=1 compiler flag")
add_compile_options("SHELL: -mllvm --lsr-drop-solution=1")
endif()
endif()
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090)
check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
if(HAS_ENABLE_POST_MISCHED)
message("Adding the enable-post-misched=0 compiler flag")
message(STATUS "Adding the enable-post-misched=0 compiler flag")
add_compile_options("SHELL: -mllvm -enable-post-misched=0")
endif()
endif()
set(check-coerce)
check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132)
message("Adding the amdgpu-coerce-illegal-types=1")
message(STATUS "Adding the amdgpu-coerce-illegal-types=1")
add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1")
endif()
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
message("Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false")
message(STATUS "Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false")
add_compile_options("SHELL: -mllvm -amdgpu-early-inline-all=true")
add_compile_options("SHELL: -mllvm -amdgpu-function-calls=false")
endif()
@@ -306,13 +312,13 @@ option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX1
if(USE_BITINT_EXTENSION_INT4)
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
add_compile_options(-Wno-bit-int-extension)
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif()
if(USE_OPT_GFX11)
add_compile_options(-mcumode)
add_compile_options(-mno-wavefrontsize64)
message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
endif()
## Threads
@@ -324,7 +330,7 @@ link_libraries(Threads::Threads)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
# https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_macros.html
# _GLIBCXX_ASSERTIONS
@@ -340,7 +346,7 @@ endif()
set(CMAKE_HIP_PLATFORM amd)
set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
set(CMAKE_HIP_EXTENSIONS ON)
message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
message(STATUS "CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
@@ -355,10 +361,10 @@ else()
find_package(OpenMP REQUIRED)
endif()
message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
message(STATUS "OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
message(STATUS "OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
message(STATUS "OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
message(STATUS "OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY})
@@ -390,146 +396,152 @@ else()
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
endif()
## tidy
include(EnableCompilerWarnings)
## tidy
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
# Enable tidy on hip
elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
set(CK_TIDY_ERRORS ALL)
set(CK_TIDY_ERRORS ALL)
endif()
include(ClangTidy)
enable_clang_tidy(
CHECKS
*
-abseil-*
-android-cloexec-fopen
# Yea we shouldn't be using rand()
-cert-msc30-c
-bugprone-exception-escape
-bugprone-macro-parentheses
-cert-env33-c
-cert-msc32-c
-cert-msc50-cpp
-cert-msc51-cpp
-cert-dcl37-c
-cert-dcl51-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-avoid-c-arrays
-cppcoreguidelines-avoid-magic-numbers
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-init-variables
-cppcoreguidelines-macro-usage
-cppcoreguidelines-non-private-member-variables-in-classes
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
-cppcoreguidelines-pro-type-member-init
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-fuchsia-*
-google-explicit-constructor
-google-readability-braces-around-statements
-google-readability-todo
-google-runtime-int
-google-runtime-references
-hicpp-vararg
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-named-parameter
-hicpp-no-array-decay
# We really shouldn't use bitwise operators with signed integers, but
# opencl leaves us no choice
-hicpp-avoid-c-arrays
-hicpp-signed-bitwise
-hicpp-special-member-functions
-hicpp-uppercase-literal-suffix
-hicpp-use-auto
-hicpp-use-equals-default
-hicpp-use-override
-llvm-header-guard
-llvm-include-order
#-llvmlibc-*
-llvmlibc-restrict-system-libc-headers
-llvmlibc-callee-namespace
-llvmlibc-implementation-in-namespace
-llvm-else-after-return
-llvm-qualified-auto
-misc-misplaced-const
-misc-non-private-member-variables-in-classes
-misc-no-recursion
-modernize-avoid-bind
-modernize-avoid-c-arrays
-modernize-pass-by-value
-modernize-use-auto
-modernize-use-default-member-init
-modernize-use-equals-default
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
# we are not ready to use it, but very useful
-readability-function-cognitive-complexity
-readability-isolate-declaration
-readability-magic-numbers
-readability-named-parameter
-readability-uppercase-literal-suffix
-readability-convert-member-functions-to-static
-readability-qualified-auto
-readability-redundant-string-init
# too many narrowing conversions in our code
-bugprone-narrowing-conversions
-cppcoreguidelines-narrowing-conversions
-altera-struct-pack-align
-cppcoreguidelines-prefer-member-initializer
${CK_TIDY_CHECKS}
${CK_TIDY_ERRORS}
HEADER_FILTER
"\.hpp$"
EXTRA_ARGS
-DCK_USE_CLANG_TIDY
)
if(ENABLE_CLANG_CPP_CHECKS)
include(ClangTidy)
enable_clang_tidy(
CHECKS
*
-abseil-*
-android-cloexec-fopen
# Yea we shouldn't be using rand()
-cert-msc30-c
-bugprone-exception-escape
-bugprone-macro-parentheses
-cert-env33-c
-cert-msc32-c
-cert-msc50-cpp
-cert-msc51-cpp
-cert-dcl37-c
-cert-dcl51-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-avoid-c-arrays
-cppcoreguidelines-avoid-magic-numbers
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-init-variables
-cppcoreguidelines-macro-usage
-cppcoreguidelines-non-private-member-variables-in-classes
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
-cppcoreguidelines-pro-type-member-init
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-fuchsia-*
-google-explicit-constructor
-google-readability-braces-around-statements
-google-readability-todo
-google-runtime-int
-google-runtime-references
-hicpp-vararg
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-named-parameter
-hicpp-no-array-decay
# We really shouldn't use bitwise operators with signed integers, but
# opencl leaves us no choice
-hicpp-avoid-c-arrays
-hicpp-signed-bitwise
-hicpp-special-member-functions
-hicpp-uppercase-literal-suffix
-hicpp-use-auto
-hicpp-use-equals-default
-hicpp-use-override
-llvm-header-guard
-llvm-include-order
#-llvmlibc-*
-llvmlibc-restrict-system-libc-headers
-llvmlibc-callee-namespace
-llvmlibc-implementation-in-namespace
-llvm-else-after-return
-llvm-qualified-auto
-misc-misplaced-const
-misc-non-private-member-variables-in-classes
-misc-no-recursion
-modernize-avoid-bind
-modernize-avoid-c-arrays
-modernize-pass-by-value
-modernize-use-auto
-modernize-use-default-member-init
-modernize-use-equals-default
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
# we are not ready to use it, but very useful
-readability-function-cognitive-complexity
-readability-isolate-declaration
-readability-magic-numbers
-readability-named-parameter
-readability-uppercase-literal-suffix
-readability-convert-member-functions-to-static
-readability-qualified-auto
-readability-redundant-string-init
# too many narrowing conversions in our code
-bugprone-narrowing-conversions
-cppcoreguidelines-narrowing-conversions
-altera-struct-pack-align
-cppcoreguidelines-prefer-member-initializer
${CK_TIDY_CHECKS}
${CK_TIDY_ERRORS}
HEADER_FILTER
"\.hpp$"
EXTRA_ARGS
-DCK_USE_CLANG_TIDY
)
include(CppCheck)
enable_cppcheck(
CHECKS
warning
style
performance
portability
SUPPRESS
ConfigurationNotChecked
constStatement
duplicateCondition
noExplicitConstructor
passedByValue
preprocessorErrorDirective
shadowVariable
unusedFunction
unusedPrivateFunction
unusedStructMember
unmatchedSuppression
FORCE
SOURCES
library/src
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/library/include
DEFINE
CPPCHECK=1
__linux__=1
)
include(CppCheck)
enable_cppcheck(
CHECKS
warning
style
performance
portability
SUPPRESS
ConfigurationNotChecked
constStatement
duplicateCondition
noExplicitConstructor
passedByValue
preprocessorErrorDirective
shadowVariable
unusedFunction
unusedPrivateFunction
unusedStructMember
unmatchedSuppression
FORCE
SOURCES
library/src
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/library/include
DEFINE
CPPCHECK=1
__linux__=1
)
else()
function(clang_tidy_check TARGET)
# stub out empty function if clang tidy is not enabled
endfunction()
endif()
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
@@ -548,7 +560,7 @@ if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
add_compile_options(-fcolor-diagnostics)
@@ -557,12 +569,15 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
add_compile_options(-fdiagnostics-color=always)
endif()
# make check runs the entire set of examples and tests
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
if(NOT MIOPEN_REQ_LIBS_ONLY)
# make check runs the entire set of examples and tests
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
endif()
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
@@ -605,6 +620,11 @@ ENDIF()
ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
add_subdirectory(library)
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
@@ -624,11 +644,13 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
endif()
endif()
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
if (NOT MIOPEN_REQ_LIBS_ONLY)
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
endif()
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
add_subdirectory(codegen)

View File

@@ -1,6 +1,6 @@
FROM ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=6.4
ARG ROCMVERSION=6.4.1
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
@@ -13,8 +13,8 @@ RUN set -xe && \
curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
RUN if [ "$ROCMVERSION" != "6.5" ]; then \
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60400-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60400-1_all.deb && \
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \

View File

@@ -1,4 +1,4 @@
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4"
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1"
FROM $BASE_DOCKER
ARG compiler_version=""
ARG compiler_commit=""

174
Jenkinsfile vendored
View File

@@ -12,6 +12,23 @@ def show_node_info() {
"""
}
class Version {
int major, minor, patch
@Override
String toString() {
return [major, minor, patch].findAll().join('.')
}
}
def parseVersion(String versionString) {
if (!versionString) return null
int[] tokens = versionString.split(/\./).collect { it as int } // Splits the string by '.' and converts each part to an integer.
return new Version(
major: tokens[0],
minor: tokens.length > 1 ? tokens[1] : null,
patch: tokens.length > 2 ? tokens[2] : null,
)
}
def nthreads() {
def nproc = sh(returnStdout: true, script: 'nproc')
echo "Number of cores: ${nproc}"
@@ -38,8 +55,8 @@ def getBaseDockerImageName(){
img = "${params.USE_CUSTOM_DOCKER}"
}
else{
def ROCM_numeric = "${params.ROCMVERSION}" as float
if ( ROCM_numeric < 6.5 ){
def ROCM_numeric = parseVersion("${params.ROCMVERSION}")
if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){
img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}"
}
else{
@@ -114,6 +131,9 @@ def check_arch(){
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
arch_type = 6
}
else if ( runShell('grep -n "gfx950" rocminfo.log') ) {
arch_type = 7
}
return arch_type
}
@@ -132,6 +152,10 @@ def getDockerImage(Map conf=[:]){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using special docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
@@ -208,6 +232,11 @@ def cmake_build(Map conf=[:]){
def build_type_debug = (conf.get("build_type",'release') == 'debug')
// use special compiler for gfx950
if ( check_arch() == 7){
compiler = "/llvm-project/build/bin/clang++"
}
//cmake_env can overwrite default CXX variables.
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
@@ -263,6 +292,9 @@ def cmake_build(Map conf=[:]){
if (setup_args.contains("gfx94")){
invocation_tag="gfx94"
}
if (setup_args.contains("gfx95")){
invocation_tag="gfx95"
}
echo "invocation tag: ${invocation_tag}"
def redis_pre_setup_cmd = pre_setup_cmd
if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") {
@@ -422,16 +454,6 @@ def buildHipClangJob(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
checkout scm
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
@@ -455,7 +477,7 @@ def buildHipClangJob(Map conf=[:]){
echo "Docker flags: ${dockerOpts}"
def variant = env.STAGE_NAME
def image
def retimage
(retimage, image) = getDockerImage(conf)
@@ -496,17 +518,6 @@ def Build_CK(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
env.DOCKER_BUILDKIT=1
checkout scm
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
@@ -527,6 +538,7 @@ def Build_CK(Map conf=[:]){
echo "Docker flags: ${dockerOpts}"
def variant = env.STAGE_NAME
def image
def retimage
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
@@ -638,6 +650,13 @@ def Build_CK(Map conf=[:]){
archiveArtifacts "perf_onnx_gemm_gfx908.log"
stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908"
}
else if ( arch == 7 ){
// run basic tests on gfx950
echo "Run performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950"
archiveArtifacts "perf_onnx_gemm_gfx950.log"
stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950"
}
}
}
if (params.hipTensor_test && arch == 1 ){
@@ -774,8 +793,8 @@ def process_results(Map conf=[:]){
}
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
@@ -800,8 +819,8 @@ pipeline {
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
string(
name: 'ROCMVERSION',
defaultValue: '6.4',
description: 'Specify which ROCM version to use: 6.3 (default).')
defaultValue: '6.4.1',
description: 'Specify which ROCM version to use: 6.4.1 (default).')
string(
name: 'COMPILER_VERSION',
defaultValue: '',
@@ -848,8 +867,8 @@ pipeline {
description: "Run the grouped conv large cases tests (default: OFF)")
booleanParam(
name: "RUN_CODEGEN_TESTS",
defaultValue: false,
description: "Run codegen tests (default: OFF)")
defaultValue: true,
description: "Run codegen tests (default: ON)")
booleanParam(
name: "RUN_CK_TILE_FMHA_TESTS",
defaultValue: false,
@@ -862,6 +881,10 @@ pipeline {
name: "RUN_CK_TILE_GEMM_TESTS",
defaultValue: false,
description: "Run the ck_tile GEMM tests (default: OFF)")
booleanParam(
name: "RUN_TILE_ENGINE_GEMM_TESTS",
defaultValue: false,
description: "Run the tile_engine_gemm tests (default: OFF)")
booleanParam(
name: "BUILD_INSTANCES_ONLY",
defaultValue: false,
@@ -870,6 +893,10 @@ pipeline {
name: "BUILD_GFX908",
defaultValue: false,
description: "Build CK and run tests on gfx908 (default: OFF)")
booleanParam(
name: "BUILD_GFX950",
defaultValue: false,
description: "Build CK and run tests on gfx950 (default: OFF)")
booleanParam(
name: "BUILD_GFX12",
defaultValue: true,
@@ -1145,6 +1172,48 @@ pipeline {
}
}
}
stage("Run TILE_ENGINE_GEMM Tests")
{
parallel
{
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make benchmark_gemm -j && \
./bin/benchmark_gemm """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make benchmark_gemm -j && \
./bin/benchmark_gemm """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Build CK and run Tests")
{
@@ -1188,7 +1257,7 @@ pipeline {
cleanWs()
}
}
stage("Build CK for all gfx9 targets")
stage("Build CK and run Tests on gfx942")
{
when {
beforeAgent true
@@ -1203,6 +1272,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1210,6 +1280,29 @@ pipeline {
cleanWs()
}
}
stage("Build CK and run Tests on gfx950")
{
when {
beforeAgent true
expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx950") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
-DGPU_TARGETS="gfx950" \
-DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx950" \
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "rocm/composable_kernel-private:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
stage("Build CK and run Tests on gfx908")
{
when {
@@ -1223,6 +1316,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx908" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1243,6 +1337,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx90a" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1250,7 +1345,7 @@ pipeline {
cleanWs()
}
}
stage("Build CK instances for different targets")
stage("Build CK instances for all supported targets")
{
when {
beforeAgent true
@@ -1276,11 +1371,12 @@ pipeline {
}
agent{ label rocmnode("gfx1030") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx10-3-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx1030" \
-DGPU_TARGETS="gfx10-3-generic" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1296,11 +1392,12 @@ pipeline {
}
agent{ label rocmnode("gfx1101") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx1101" \
-DGPU_TARGETS="gfx11-generic" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1316,11 +1413,12 @@ pipeline {
}
agent{ label rocmnode("gfx1201") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1201" -DCMAKE_CXX_FLAGS=" -O3 " """
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx1201" \
-DGPU_TARGETS="gfx12-generic" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{

View File

@@ -32,7 +32,7 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_BF16)
set(CK_ENABLE_BF16 "ON")
endif()
message("DTYPES macro set to ${DTYPES}")
message(DEBUG "DTYPES macro set to ${DTYPES}")
else()
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
set(CK_ENABLE_INT8 "ON")

View File

@@ -66,7 +66,8 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
# Werror set outside by BUILD_DEV
# -Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
@@ -108,7 +109,7 @@ else()
endif()
list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers
-Wno-deprecated-declarations
-Wno-error=deprecated-declarations
)
endif()
add_definitions(${CMAKE_COMPILER_WARNINGS})

View File

@@ -19,9 +19,7 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
include(Embed)
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${CK_ROOT}/include/ck/*.hpp)
# printouts fot debug purposes
# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
# message(STATUS "RELATIVE: ${CK_ROOT}/include")
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
add_compile_options(-std=c++17)
@@ -48,6 +46,7 @@ rocm_install_targets(
INCLUDE include
)
rocm_export_targets(
TARGETS ck_host ck_headers
EXPORT ck_host_targets
NAMESPACE composable_kernel::
)

View File

@@ -8,5 +8,5 @@ target_link_libraries(ck_rtc PUBLIC -lstdc++fs)
option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON)
if(USE_HIPRTC_FOR_CODEGEN_TESTS)
target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS)
message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}")
endif()

View File

@@ -1,2 +1,2 @@
rocm-docs-core[api_reference]==1.18.4
rocm-docs-core[api_reference]==1.20.1
sphinxcontrib-bibtex==2.6.3

View File

@@ -237,7 +237,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core[api-reference]==1.18.4
rocm-docs-core[api-reference]==1.20.1
# via -r requirements.in
rpds-py==0.24.0
# via

19
example/01_gemm/CMakeLists.txt Executable file → Normal file
View File

@@ -39,6 +39,12 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16")
example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
list(APPEND gpu_list gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -109,3 +115,16 @@ add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16)
add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8)
add_example_executable(example_gemm_wmma_bf16_v3 gemm_wmma_bf16_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_v3)
add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3)
add_example_executable(example_gemm_wmma_fp8_v3 gemm_wmma_fp8_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3)
add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3)
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3)
add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)

View File

@@ -15,6 +15,8 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
@@ -57,8 +59,9 @@ struct ProblemSizeStreamK_universal final
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;
ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
ck::index_t Grid_size = -1; // defaults to max occupancy
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
ck::StreamKReductionStrategy reduction_strategy = ck::StreamKReductionStrategy::Atomic;
};
struct ProblemSizeSplitK final
@@ -128,11 +131,12 @@ bool parse_cmd_args<ProblemSize>(int argc,
}
else
{
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
std::cerr
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
<< std::endl;
return false;
}
@@ -172,7 +176,19 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
if(argc >= 11)
{
problem_size.Streamk_sel = std::stoi(argv[10]);
problem_size.Grid_size = std::stoi(argv[11]);
if(argc >= 12)
{
problem_size.Grid_size = std::stoi(argv[11]);
if(argc >= 13)
{
int reduction_strategy = std::stoi(argv[12]);
problem_size.reduction_strategy = reduction_strategy == 0
? ck::StreamKReductionStrategy::Atomic
: ck::StreamKReductionStrategy::Reduction;
}
}
}
}
else
@@ -181,9 +197,12 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
<< std::endl
<< "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
<< std::endl
<< "arg11: Grid_size(-1 for max occupancy)" << std::endl
<< "arg12: Reduction strategy (0: Atomic, 1: Reduction)" << std::endl;
return false;
}
@@ -227,13 +246,14 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
}
else
{
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
std::cerr
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
<< std::endl
<< "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)"
<< "\narg11: Grid_size(-1 for max occupancy)" << std::endl;
return false;
}
@@ -277,12 +297,13 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
}
else
{
std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg10: KBatch" << std::endl;
std::cerr
<< "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)"
<< std::endl
<< "arg10: KBatch" << std::endl;
return false;
}

View File

@@ -0,0 +1,253 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::pk_i4_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 32;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
128, 128, KPerBlock,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1,
ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
if constexpr(PermuteB)
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// int K0, N, K1
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
}
}
}
}
else
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
}
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmV2Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N /
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
bool run_gemm_splitk_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
256,
128, 128, 32,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
using ADataType = ck::half_t;
using BDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
128, 128, 32,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
#include "run_gemm_example_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,302 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
using ADataType = ck::half_t;
using BDataType = ck::pk_i4_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 32;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
128, 128, KPerBlock,
8, 8,
16, 16,
4, 2,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 1,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1,
ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
if constexpr(PermuteB)
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// int K0, N, K1
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
}
}
}
}
else
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
}
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmV2Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N /
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
bool run_gemm_splitk_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,47 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64,
64, 8, 8,
16, 16,
4, 2,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, 1,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp"
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using ComputeTypeA = ck::f8_t;
using ComputeTypeB = ck::f8_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
PassThrough, PassThrough, PassThrough, GemmDefault,
128,
128, 64, 64,
8, 8,
16, 16,
4, 2,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
1, 1, S<1, 32, 1, 4>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
ComputeTypeA, ComputeTypeB>;
// clang-format on
using ReferenceComputeType = ck::f8_t;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ReferenceComputeType,
ReferenceComputeType>;
#include "run_gemm_example_v2.inc"
int main(int argc, char* argv[])
{
if(!ck::is_gfx12_supported())
{
std::cout << "This kernel support gfx12 only" << std::endl;
return 0;
}
return !run_gemm_splitk_example(argc, argv);
}

0
example/01_gemm/gemm_xdl_bf16.cpp Executable file → Normal file
View File

0
example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp Executable file → Normal file
View File

0
example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp Executable file → Normal file
View File

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
@@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>;
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#else
// clang-format off

View File

@@ -33,7 +33,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)

View File

@@ -36,7 +36,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)

View File

@@ -21,6 +21,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto Grid_size = problem_size.Grid_size;
auto Streamk_sel = problem_size.Streamk_sel;
auto reduction_strategy = problem_size.reduction_strategy;
if(reduction_strategy == ck::StreamKReductionStrategy::Atomic)
{
std::cout << "Using Atomic reduction strategy" << std::endl;
}
else
{
std::cout << "Using Parallel reduction strategy" << std::endl;
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
@@ -35,7 +45,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
@@ -152,7 +162,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Grid_size,
a_element_op,
b_element_op,
c_element_op);
c_element_op,
reduction_strategy);
if(!gemm.IsSupportedArgument(argument))
{
@@ -242,7 +253,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
<< " GB/s, " << gemm.GetTypeString()
<< (reduction_strategy == ck::StreamKReductionStrategy::Atomic ? " (Atomic)"
: " (Reduction)")
<< std::endl;
}
return pass;
}

View File

@@ -34,7 +34,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
if(stride == -1 || stride == 0)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
@@ -34,7 +34,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -71,9 +71,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
@@ -84,14 +84,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
#else

View File

@@ -1,11 +1,20 @@
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
set(EXAMPLE_COMPILE_OPTIONS)
# Open it when SGBPack branch landed on mainline
# list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental")
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp)
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
list(APPEND gpu_list gfx942 gfx950)
set(target 0)
@@ -19,14 +28,32 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(HAS_MAX_ILP_SCHEDULING_STRATEGY)
list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1)
endif()
target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
endif()
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
target_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
target_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
target_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
set(target 1)
endif()
endforeach()
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
set(BLOCKSCALE_GEMM_OPTIONS)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL)
if(HAS_MAX_OCCUPANCY_EXPERIMENTAL)
list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental)
endif()
# list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1")
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
16, 128,
256, 16, 16,
128, 128,
128, 16, 16,
16, 16,
1, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 16, 1, 16>, S<8>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 2, S<1, 32, 1, 8>, S<8>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on
int main(int argc, char* argv[])

View File

@@ -0,0 +1,372 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using FP8 = ck::f8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = FP8;
using A1DataType = F32;
using B0DataType = FP8;
using B1DataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using A1Layout = Col;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K + k];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
using DeviceOpInstance =
ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
// clang-format off
<Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
128, 128,
128, 16, 16,
16, 16,
8, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 1, S<1, 32, 1, 8>, S<8>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
bool flush_cache = true;
// GEMM shape
ck::index_t M = 128;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 8)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
flush_cache = std::stoi(argv[7]);
StrideA = K;
StrideB = K;
StrideE = N;
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: M, N, K\n");
printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n");
exit(0);
}
// Transpose the AScale tensor for better performance
ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M,
(K + Scale_Block_K - 1) / Scale_Block_K,
Scale_Stride_AK,
A1Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B0DataType> b0_preshuffled(
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
(N + Scale_Block_N - 1) / Scale_Block_N,
Scale_Stride_BN,
B0Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
case 2:
a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 3:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 4:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
a1_device_buf.ToDevice(a1_m_k.mData.data());
b1_device_buf.ToDevice(b1_k_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{},
StrideE,
a1_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float ave_time = 0.0f;
if(flush_cache)
{
int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype;
ave_time = invoker.Run(argument,
StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf});
}
else
{
ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100});
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
if(do_verification)
{
Tensor<AccDataType> c_m_n({M, N});
Tensor<float> a_m_k({M, K});
Tensor<float> b_k_n({K, N});
for(int m = 0; m < M; m++)
{
for(int k = 0; k < K; k++)
{
a_m_k(m, k) = ck::type_convert<float>(a0_m_k(m, k)) *
a1_m_k(m / Scale_Block_M, k / Scale_Block_K);
}
}
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<float,
float,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
#if 1
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
e_m_n_host_result(m, n) = ck::type_convert<EDataType>(c_m_n(m, n));
}
}
#endif
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2)
? 0
: 1;
}
return 0;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -139,13 +139,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
128, 128, 128,
256, 256, 128,
16, 16,
32, 32,
4, 1,
16, 16,
16, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
2, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// clang-format on

View File

@@ -158,21 +158,22 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
@@ -183,15 +184,15 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
// mn_perxdl
MNPerXDL, MNPerXDL,
// mn_xdlperwave
MXDLPerWave, NXDLPerWave,
MXDLPerWave, NXDLPerWave,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
2, 2, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>;
// clang-format on
@@ -205,9 +206,9 @@ int main(int argc, char* argv[])
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t tokens = 64;
ck::index_t sorted_tile_num = 256;
ck::index_t valid_tile_num = 256;
ck::index_t tokens = 16384;
ck::index_t topk = 2;
if(argc == 1)
@@ -263,11 +264,12 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = i / (valid_tile_num / experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
@@ -307,7 +309,7 @@ int main(int argc, char* argv[])
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.1, 0.1});
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});

View File

@@ -0,0 +1,548 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using A1DataType = F32;
using B0DataType = F8;
using B1DataType = F32;
// using EDataType = F16;
using EDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D2Layout>;
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D2>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d2);
}
};
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
int KPack = 16 / sizeof(B0DataType);
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = true;
#if 0
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
// clang-format off
< Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, NPerBlock, KPerBlock,
// ak1, bk1
AK1, BK1,
// mn_perxdl
MNPerXDL, MNPerXDL,
// mn_xdlperwave
MXDLPerWave, NXDLPerWave,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
#if 1
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t topk = 2;
// ck::index_t sorted_tile_num = 515;
// ck::index_t valid_tile_num = 512;
// ck::index_t tokens = 8192;
// ck::index_t sorted_tile_num = 15;
// ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 259;
ck::index_t valid_tile_num = 256;
ck::index_t tokens = 4096;
#else
// deepseek
ck::index_t N = 2048;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t topk = 8;
ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 261;
ck::index_t valid_tile_num = 256;
#endif
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<A1DataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(
HostTensorDescriptor({experts,
(K + Scale_Block_K - 1) / Scale_Block_K,
(N + Scale_Block_N - 1) / Scale_Block_N * 2},
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
a1_device_buf.ToDevice(a1_t_k.mData.data());
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
a1_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
sizeof(B0DataType) * K * N * 2 * experts +
sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s.\n"
<< device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> a_t_k({tokens, K});
Tensor<float> b_e_n_k({experts, K, N * 2});
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
// handle scale before ref.
for(int t = 0; t < tokens; ++t)
{
for(int k = 0; k < K; ++k)
{
a_t_k(t, k) = ck::type_convert<float>(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K);
}
}
for(int e = 0; e < experts; ++e)
{
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N * 2; ++n)
{
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
}
}
}
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm1BlockScale<float,
float,
float,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a_t_k,
b_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -123,11 +123,11 @@ using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 256;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 4;
static constexpr ck::index_t MXDLPerWave = 16;
static constexpr ck::index_t NXDLPerWave = 4;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t NPerBlock = 256;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
@@ -164,12 +164,12 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
// S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -186,11 +186,11 @@ int main(int argc, char* argv[])
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 16;
ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 133;
ck::index_t valid_tile_num = 128;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t tokens = 128;
ck::index_t tokens = 16384;
ck::index_t topk = 2;
if(argc == 1)
@@ -245,13 +245,14 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
// max_token_id.mData[0] = valid_size;
// max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
}
if(tokens * topk > valid_size)
{

View File

@@ -0,0 +1,541 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using A1DataType = F32;
using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
// using EDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
// using DsLayoutGate = ck::Tuple<D0Layout, D1Layout>;
using DsLayout = ck::Tuple<D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D2>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d2);
}
};
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
int KPack = 16 / sizeof(B0DataType);
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr bool MulRoutedWeight = true;
#if 0
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
static constexpr ck::index_t CShuffleNLane = 16;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, NPerBlock, KPerBlock,
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// tokens = 1
// topk = 1
// experts = 8
// per expert:
constexpr ck::index_t valid_tile_num =
26; // 13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock
constexpr ck::index_t sorted_tile_num = valid_tile_num + 3;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
#if 1
// GEMM shape
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
#else
// deepseek
ck::index_t N = 2048;
ck::index_t K = 7160;
ck::index_t experts = 256;
ck::index_t tokens = 1;
ck::index_t topk = 8;
#endif
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N;
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
// int eids[] = {0, 1, 3, 3, 3};
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
// int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
// 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
// 3, 3, 3, 3, 3, 3, 3, 3, 4, 4,
// 5, 5, 5, 5, 6, 6, 6, 6, 7, 7,
// 7, 7,
// 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<A1DataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(HostTensorDescriptor(
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N},
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{1.0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{1.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{1.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{1.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{1.0, 1.0});
for(auto i = 0; i < N * K; i++)
{
b0_e_n_k.mData[i] = ck::type_convert<B0DataType>(static_cast<float>(0.1));
b0_e_n_k.mData[i + N * K] = ck::type_convert<B0DataType>(static_cast<float>(0.2));
}
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
a1_device_buf.ToDevice(a1_t_k_k.mData.data());
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
a1_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * K;
std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk +
sizeof(B0DataType) * K * N * experts +
sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s.\n"
<< device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> a_t_k_k({tokens, topk, K});
Tensor<float> b_e_n_k({experts, K, N});
Tensor<float> c_t_n({tokens, N});
for(int t = 0; t < tokens; ++t)
{
for(int tk = 0; tk < topk; ++tk)
{
for(int k = 0; k < K; ++k)
{
a_t_k_k(t, tk, k) = ck::type_convert<float>(a0_t_k_k(t, tk, k)) *
a1_t_k_k(t, tk, k / Scale_Block_K);
}
}
}
for(int e = 0; e < experts; ++e)
{
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N; ++n)
{
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
}
}
}
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
float,
float,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a_t_k_k,
b_e_n_k,
d2_e_n,
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

0
example/66_complex_contraction_bilinear/CMakeLists.txt Executable file → Normal file
View File

0
example/66_complex_contraction_bilinear/README.md Executable file → Normal file
View File

View File

@@ -6,6 +6,33 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
# TODO: Fix RRR
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle)
add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
set(FP4_MXGEMM_OPTIONS)
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})

View File

@@ -21,11 +21,11 @@ using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 128;
constexpr ck::index_t KPerBlock = 256;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
@@ -45,32 +45,32 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
ScaleBlockSize, // ScaleBlockSize: Scaling block size
128, // BlockSize: Thread block size
128, // MPerBlock
16, // NPerBlock
32, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
1, // NXdlPerWave
S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // NXdlPerWave
S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
true, // ABlockLdsExtraM
S<16, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -23,8 +23,9 @@
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -36,6 +37,8 @@ struct ExecutionConfig final
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
int warm_up = 10;
int repeat = 10;
};
struct ProblemSizeSplitK final
@@ -86,6 +89,8 @@ bool parse_cmd_args(int argc,
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
config.warm_up = std::stoi(argv[12]);
config.repeat = std::stoi(argv[13]);
}
}
else
@@ -103,10 +108,90 @@ bool parse_cmd_args(int argc,
return true;
}
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
// 2-k)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename XPackedDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -119,6 +204,8 @@ template <typename DeviceOpInstance,
ck::index_t ScaleBlockSize>
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
{
constexpr bool BPreShuffle = ck::is_same_v<BLayout, MFMA>;
using BRefLayout = ck::conditional_t<BPreShuffle, Col, BLayout>;
auto M = problem_size.M;
auto N = problem_size.N;
@@ -131,28 +218,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1});
}
else
{
return HostTensorDescriptor({row, col}, {1, stride});
}
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<ck::index_t>(col);
}
else
{
return static_cast<ck::index_t>(row);
}
}
else
return static_cast<ck::index_t>(stride);
@@ -172,16 +250,30 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
using AScaleLayout = Row;
using BScaleLayout = Col;
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Padded_M = ck::math::integer_least_multiple(M, ScaleBlockSize);
auto Scale_Stride_AM =
f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
auto b_k_n =
std::make_shared<Tensor<BDataType>>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{}));
auto b_input = b_k_n;
if constexpr(BPreShuffle)
b_input = std::make_shared<Tensor<BDataType>>(
f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size
// scales for A and B
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<XDataType> b_k_n_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_k_n_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
// shuffled scales for A and B
Tensor<XDataType> a_shuffled_scale(f_host_tensor_descriptor(
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_shuffled_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
@@ -192,18 +284,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
{
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n->mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl;
}
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else
return ck::type_convert<BDataType>(x);
};
using int_distr = std::uniform_int_distribution<int>;
using float_distr = std::uniform_real_distribution<float>;
switch(config.init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
@@ -215,31 +322,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
{
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
}
else
{
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
}
a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});
b_k_n_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
break;
default:
@@ -249,20 +344,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
}
}
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<BRefLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
if constexpr(BPreShuffle)
{
int NPerXdl = 16; // Fixed 16
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl);
}
if(config.verbosity > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize());
if(config.verbosity > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
b_device_buf.ToDevice(b_input->mData.data());
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
if(config.verbosity > 0)
std::cout << "Done." << std::endl;
@@ -275,9 +383,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
@@ -299,13 +407,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
"not consistent with the supported device_gemm arguments.");
}
std::size_t total_size =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() +
a_shuffled_scale.GetElementSpaceSizeInBytes() +
b_shuffled_scale.GetElementSpaceSizeInBytes();
const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size);
const int rotating_count = std::max(1, std::min(config.repeat, static_cast<int>(total_cnt)));
if(config.verbosity > 0)
{
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
float ave_time = invoker.Run(argument,
StreamConfig{nullptr,
config.time_kernel,
config.verbosity,
config.warm_up,
config.repeat,
rotating_count > 1,
rotating_count});
bool res_verified = true;
if(config.do_verification > 0)
@@ -332,7 +453,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
*b_k_n,
b_k_n_scale,
c_m_n_host_result,
PassThrough{},
@@ -347,20 +468,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
std::cout << "Comparing results..." << std::endl;
}
if(config.init_method == 0)
{
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
<< std::endl;
}
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!");
res_verified =
res_verified &&
ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1);
if(config.verbosity > 0 && res_verified)
std::cout << "Verification Successful!" << std::endl;
@@ -377,13 +488,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
std::size_t num_btype =
sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> + sizeof(CDataType) * M * N +
sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
float gb_per_sec = static_cast<float>(num_btype) / 1e6f / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
@@ -396,6 +508,7 @@ template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename XPackedDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -416,6 +529,7 @@ bool run_mx_gemm_example(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// AB DataType: f4x2_pk_t
// Mathmatically, all numbers are represented as f4x2.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
256, // MPerBlock
256, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -0,0 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = MFMA;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// AB DataType: f4x2_pk_t
// Mathmatically, all numbers are represented as f4x2.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
512, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
4, // CShuffleNXdlPerWavePerShuffle
S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlockW
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -25,7 +25,7 @@ constexpr ck::index_t KPerBlock = 256;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
@@ -49,26 +49,26 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -24,7 +24,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
@@ -43,30 +43,30 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
256, // MPerBlock
256, // NPerBlock
128, // KPerBlock
128, // MPerBlock
128, // NPerBlock
256, // KPerBlock
16, // AK1
8, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<32, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
@@ -82,6 +82,7 @@ int main(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -0,0 +1,545 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t BlockSize = 256;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, BlockSize,
MPerBlock, NPerBlock, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
max_token_id.mData[0] = valid_size;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_k_n_host_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_k_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_k_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{0.5f});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.5f});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k(token_id, k);
}
}
}
// A/B scale shuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop =
// FMA * tokens * N * (Gate+Up) * topk * K +
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(EDataType) * tokens * topk * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
XDataType,
B0DataType,
XDataType,
CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k,
a1_t_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_k_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -0,0 +1,526 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, float, float, float, float>(
EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<XDataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
}
}
}
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// FMA * tokens * N * topk * K +
// FMA * tokens * N * topk * (K/BlockScale)
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<CShuffleDataType> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
XDataType,
B0DataType,
XDataType,
D2DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k_k,
a1_t_k_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n, // topk weights
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}

View File

@@ -20,7 +20,7 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
endfunction(add_example_dependencies EXAMPLE_NAME)
function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
message(DEBUG "adding example ${EXAMPLE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
@@ -47,7 +47,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set(test 1)
endif()
if(test EQUAL 1)
message("removing example source file ${source} ")
message(DEBUG "removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
@@ -58,70 +58,72 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
message(DEBUG "removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any DPP examples if DPP_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
message("removing dpp example ${source} ")
message(DEBUG "removing dpp example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
message(DEBUG "removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
message(DEBUG "removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any microscaling examples if gfx950 target is not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
message("removing microscaling example ${source} ")
message(DEBUG "removing microscaling example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
message("removing fp8 example ${source} ")
message(DEBUG "removing fp8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
message("removing bf8 example ${source} ")
message(DEBUG "removing bf8 example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_fp8_bpreshuffle")
message("Skipping ${source} example for current target")
list(REMOVE_ITEM FILE_NAME "${source}")
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)")
message(DEBUG "Skipping ${source} example for current target")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
message("trimming targets for ${FILE_NAME}")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
message(DEBUG "trimming targets for ${FILE_NAME}")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -133,7 +135,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
message(DEBUG "add_example returns ${result}")
if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES)
set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST")
add_dependencies(smoke ${EXAMPLE_NAME})
@@ -151,7 +153,7 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
endfunction(add_example_dependencies EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
message(DEBUG "adding example ${EXAMPLE_NAME}")
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
@@ -178,7 +180,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
set(test 1)
endif()
if(test EQUAL 1)
message("removing example ${source} ")
message(DEBUG "removing example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
@@ -189,28 +191,28 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
message(DEBUG "removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any XDL examples if gfx9 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
message("removing xdl example ${source} ")
message(DEBUG "removing xdl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach(source IN LISTS FILE_NAME)
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
message("removing wmma example ${source} ")
message(DEBUG "removing wmma example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
endif()
@@ -222,12 +224,18 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
message(DEBUG "add_example returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
function(example_compile_options EXAMPLE_NAME)
if(TARGET ${EXAMPLE_NAME})
target_compile_options(${EXAMPLE_NAME} ${ARGN})
endif()
endfunction(example_compile_options)
# add all example subdir
file(GLOB dir_list LIST_DIRECTORIES true *)
FOREACH(subdir ${dir_list})

View File

@@ -25,7 +25,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
endif()
execute_process(
@@ -34,7 +34,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -57,7 +57,7 @@ add_custom_command(
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_FWD}")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
@@ -65,7 +65,7 @@ target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}")
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})

View File

@@ -3,7 +3,7 @@
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
import fnmatch
import itertools
from pathlib import Path
@@ -117,8 +117,50 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
FMHA_FWD_API="""
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
#include <cstdio>
namespace {{
bool get_num_cus(unsigned& num_cu) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cu = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
float r = -1;
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
{F_dispatch}
return r;
}}
@@ -134,36 +176,50 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return 'true'
else:
return f'{self.bool_expr}'
def __and__(self, other):
return CppConstraint(f'({str(self)}) && ({str(other)})')
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
constraint : CppConstraint
@property
def name(self) -> str:
@@ -220,17 +276,18 @@ class FmhaFwdApiTrait:
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
@@ -297,8 +354,8 @@ class FmhaFwdApiPool:
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
@@ -313,25 +370,27 @@ class FmhaFwdApiPool:
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
@@ -423,33 +482,21 @@ class FmhaFwdKernel:
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
class KernelComponentFactory:
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@staticmethod
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
@@ -458,53 +505,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
if bias == "bias":
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
tiles = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not

View File

@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
{F_occupancy},
{F_skip}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
#include <iostream>
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
skpad : str
dpad : str
dvpad : str
skip : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
@property
def scheck(self) -> str:
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
@property
def name(self) -> str:
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdApiPool:
@@ -293,7 +300,7 @@ class FmhaFwdApiPool:
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
@@ -381,6 +388,7 @@ class FmhaFwdKernel:
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -419,7 +427,8 @@ class FmhaFwdKernel:
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@@ -453,36 +462,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
@@ -508,7 +517,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
@@ -532,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# PyTorch integration
@@ -540,6 +550,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
@@ -565,6 +576,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -169,6 +169,7 @@ struct fmha_fwd_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
float p_drop;
bool s_randval;
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_left,
args.window_size_right,
args.mask_type,
args.min_seqlen_q,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
bool skip_min_seqlen_q = false;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

View File

@@ -25,7 +25,7 @@ add_custom_command(
set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd")
message("adding example ${EXAMPLE_LAYERNORM2D_FWD}")
message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}")
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp)
target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})

View File

@@ -18,9 +18,12 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
if constexpr(Persistent)
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -214,4 +217,15 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -213,11 +213,20 @@ auto create_args(int argc, char* argv[])
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
bool Persistent = false>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -162,7 +162,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
int n_repeat,
bool persistent)
{
ck_tile::GemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
@@ -176,9 +177,31 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
float ave_time;
if(persistent)
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
true>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
else
{
ave_time = gemm_calc<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
false>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -193,8 +216,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
@@ -229,6 +252,7 @@ int run_gemm_example_with_layouts(int argc,
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -316,7 +340,8 @@ int run_gemm_example_with_layouts(int argc,
stride_C,
kbatch,
n_warmup,
n_repeat);
n_repeat,
persistent);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -11,19 +11,7 @@
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename Pipeline, ck_tile::TailNumber TN>
void try_run(ck_tile::TailNumber tn)
{
if constexpr(Pipeline::PrefetchStages > static_cast<int>(TN))
{
if(tn == TN)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, TN>{});
}
}
}
#include "run_gemm_example.inc"
template <typename ADataType,
typename BDataType,
@@ -31,7 +19,8 @@ template <typename ADataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename CLayout,
bool Persistent>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -60,7 +49,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity>;
GemmConfig::UseStructuredSparsity,
Persistent>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -74,64 +64,113 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto memory_operation = memory_operation_.value;
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
@@ -150,101 +189,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
}
};
if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
if(tail_num == ck_tile::TailNumber::One)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
auto check_tail = [&](auto... TNs) {
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
};
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_gemm_example.inc"
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
@@ -345,7 +294,7 @@ int main(int argc, char* argv[])
{
try
{
run_gemm_example(argc, argv);
return !run_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{

View File

@@ -1,7 +1,7 @@
set(EXAMPLE_REDUCE "tile_example_reduce")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_REDUCE}")
message(DEBUG "adding example ${EXAMPLE_REDUCE}")
add_executable(${EXAMPLE_REDUCE} EXCLUDE_FROM_ALL reduce.cpp)
target_include_directories(${EXAMPLE_REDUCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

View File

@@ -25,7 +25,7 @@ add_custom_command(
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
message("adding ${TILE_RMSNORM2D_FWD}")
message(DEBUG "adding ${TILE_RMSNORM2D_FWD}")
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})

View File

@@ -1,7 +1,7 @@
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp)
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

View File

@@ -1,5 +1,5 @@
function (add_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
message(DEBUG "adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})

View File

@@ -334,16 +334,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
try
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec == "fp32" && index_prec == "int32")
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
return r ? 0 : -1;
}

View File

@@ -1,5 +1,5 @@
function (add_moe_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
message(DEBUG "adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})

View File

@@ -1,7 +1,7 @@
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
message(DEBUG "adding ${TILE_EXAPMLE_FUSED_MOE}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

View File

@@ -183,141 +183,22 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
}
};
if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
std::ostringstream err;
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
"got "
<< tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_batched_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -197,121 +197,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
}
};
if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Incorrect tail_num for compv3 pipeline! Expected Full, Odd or Even, but got "
<< tail_num << "\nPrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
}
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
std::ostringstream err;
err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but "
<< "got " << tail_num << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
@@ -319,4 +205,15 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = false;
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_grouped_gemm_example<Persistent>(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -3,6 +3,6 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -Wno-unused-local-typedef)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -11,6 +11,7 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
template <typename ADataType,
typename BDataType,
@@ -21,49 +22,22 @@ template <typename ADataType,
typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 2;
// This part comes from the Codegen
#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 128;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 128;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 8;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
#endif
using CodegenFlatmmShape =
ck_tile::TileFlatmmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using FlatmmConfig = FlatmmConfig<ADataType>;
using CodegenFlatmmShape = ck_tile::TileFlatmmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenFlatmmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
@@ -81,11 +55,11 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation>>;
@@ -109,15 +83,57 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName()
<< CodegenPipelineProblem::GetName() << " grid: {" << grids.x << ", "
<< grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
float ave_time{0};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
if(args.k_batch == 1)
@@ -132,8 +148,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
}
}
#include "run_flatmm_example.inc"
int run_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -177,4 +191,15 @@ int run_flatmm_example(int argc, char* argv[])
return -1;
}
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_flatmm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -109,6 +109,43 @@ struct is_8bit_type
{
};
template <typename ADataType>
struct FlatmmConfig
{
#if defined(USING_MFMA_16x16x32)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
#elif defined(USING_MFMA_32x32x16)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
#endif
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr int kBlockPerCu = 2;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -133,4 +170,11 @@ auto create_args(int argc, char* argv[])
}
// host API
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -32,38 +32,20 @@ static constexpr inline auto is_row_major(Layout layout_)
}
// mfma_type, 0:32x32, 1:16x16
template <typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type)
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 16, 2, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 32, 4, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
return t;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
@@ -122,7 +104,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
float ave_time =
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -149,10 +131,11 @@ int run_flatmm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using FlatmmConfig = FlatmmConfig<ADataType>;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
@@ -163,8 +146,9 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -188,13 +172,8 @@ int run_flatmm_example_with_layouts(int argc,
c_rslt_host.SetZero();
// do pre-shuffle
std::string mfma = arg_parser.get_str("prec");
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
ck_tile::index_t mfma_type = 1;
#else
ck_tile::index_t mfma_type = 0;
#endif
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type);
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());

View File

@@ -85,6 +85,20 @@ struct FillUniformDistributionIntegerValue
}
};
/**
* @brief A functor for filling a container with a monotonically increasing or decreasing sequence.
*
* FillMonotonicSeq generates a sequence of values starting from an initial value
* and incrementing by a fixed step for each subsequent element.
*
* @tparam T The numeric type of the sequence elements.
*
* Example usage:
* ```
* std::vector<int> v(5);
* FillMonotonicSeq<int>{10, 2}(v); // Fills v with {10, 12, 14, 16, 18}
* ```
*/
template <typename T>
struct FillMonotonicSeq
{

View File

@@ -8,6 +8,7 @@
#include <iostream>
#include <fstream>
#include <numeric>
#include <random>
#include <thread>
#include <utility>
#include <vector>
@@ -18,6 +19,7 @@
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/ranges.hpp"
#include "ck/library/utility/thread.hpp"
template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
@@ -512,6 +514,72 @@ struct Tensor
}
}
// Generate random values with multiple threads. Guaranteed to give the same sequence with any
// number of threads provided.
template <typename Distribution = std::uniform_real_distribution<float>,
typename Mapping = ck::identity,
typename Generator = std::minstd_rand>
void GenerateTensorDistr(Distribution dis = {0.f, 1.f},
Mapping fn = {},
const Generator g = Generator(0), // default seed 0
std::size_t num_thread = -1)
{
using ck::math::integer_divide_ceil;
using ck::math::min;
if(num_thread == -1ULL)
num_thread = min(ck::get_available_cpu_cores(), 80U); // max 80 threads
// At least 2MB per thread
num_thread = min(num_thread, integer_divide_ceil(this->GetElementSpaceSize(), 0x200000));
constexpr std::size_t BLOCK_BYTES = 64;
constexpr std::size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T);
const std::size_t num_blocks = integer_divide_ceil(this->GetElementSpaceSize(), BLOCK_SIZE);
const std::size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread);
std::vector<std::thread> threads;
threads.reserve(num_thread - 1);
const auto dst = const_cast<T*>(this->mData.data());
const auto element_space_size = this->GetElementSpaceSize();
for(int it = num_thread - 1; it >= 0; --it)
{
std::size_t ib_begin = it * blocks_per_thread;
std::size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks);
auto job = [=]() {
auto g_ = g; // copy
auto dis_ = dis; // copy
g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v<T>);
auto t_fn = [&]() {
if constexpr(ck::packed_size_v<T> == 1)
return ck::type_convert<T>(fn(dis_(g_)));
else if constexpr(ck::is_same_v<T, ck::f4x2_pk_t>)
return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(
ck::float2_t{ck::type_convert<float>(fn(dis_(g_))),
ck::type_convert<float>(fn(dis_(g_)))})};
else
static_assert(false, "Unsupported packed size for T");
};
std::size_t ib = ib_begin;
for(; ib < ib_end - 1; ++ib)
ck::static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) {
constexpr size_t iw = iw_.value;
dst[ib * BLOCK_SIZE + iw] = t_fn();
});
for(std::size_t iw = 0; iw < BLOCK_SIZE; ++iw)
if(ib * BLOCK_SIZE + iw < element_space_size)
dst[ib * BLOCK_SIZE + iw] = t_fn();
};
if(it > 0)
threads.emplace_back(std::move(job));
else
job(); // last job run in the main thread
}
for(auto& t : threads)
t.join();
}
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{

View File

@@ -163,6 +163,18 @@ struct GeneratorTensor_1<ck::pk_i4_t>
}
};
template <>
struct GeneratorTensor_1<ck::e8m0_bexp_t>
{
float value = 1;
template <typename... Is>
ck::e8m0_bexp_t operator()(Is...)
{
return ck::type_convert<ck::e8m0_bexp_t>(value);
}
};
template <typename T>
struct GeneratorTensor_2
{

View File

@@ -0,0 +1,25 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef __linux__
#include <sched.h>
#endif
#include <thread>
namespace ck {
inline unsigned int get_available_cpu_cores()
{
#if defined(__linux__)
cpu_set_t cpu_set;
if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0)
{
unsigned int cpu_count = CPU_COUNT(&cpu_set);
if(cpu_count > 0)
return cpu_count;
}
#endif
// Fallback if sched_getaffinity unavailable or fails
return std::thread::hardware_concurrency();
}
} // namespace ck

View File

@@ -35,6 +35,9 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
using ComputeTypeB = BDataType;
using AccType = float; // for now only support V_MFMA_SCALE_F32
static constexpr index_t APackedSize = packed_size_v<ComputeTypeA>;
static constexpr index_t BPackedSize = packed_size_v<ComputeTypeB>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -48,17 +51,24 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
// static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr index_t B_K1 =
BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, TransposeC, true>{};
static constexpr auto xdlops_gemm = XdlopsGemm<ComputeTypeA,
MPerXDL,
NPerXDL,
KPack * APackedSize,
ComputeTypeB,
TransposeC,
true>{};
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
//> store rows/cols into thread registers in chunks of 16
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
static constexpr index_t KThreadChunk = 16;
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
@@ -67,22 +77,29 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops>;
// Hardcode to 2, for better 8-bit access pattern
static constexpr index_t MXdlPack = 2;
static constexpr index_t NXdlPack = 2;
static constexpr index_t KXdlPack = 2;
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< //
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops,
(packed_size_v<ComputeTypeA> > 1 || packed_size_v<ComputeTypeB> > 1)>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
@@ -116,7 +133,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -127,7 +144,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -142,24 +159,27 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(
make_unmerge_transform(make_tuple(MRepeat / MXdlPack, MWaves, MXdlPack, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Sequence<0, 1, 2, 3>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(
make_unmerge_transform(make_tuple(NRepeat / NXdlPack, NWaves, NXdlPack, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Sequence<0, 1, 2, 3>{}));
// We pack 2 mfma in M/N direction, so we need to divide by 2
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
make_tuple(m0 / MXdlPack, waveId_m, m0 % MXdlPack, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
make_tuple(n0 / NXdlPack, waveId_n, n0 % NXdlPack, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
@@ -179,13 +199,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
@@ -221,6 +240,28 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
I1,
I1,
Number<MXdlPack>{},
Number<NXdlPack>{},
M0,
M1,
M2,
N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -262,6 +303,23 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl_packed mfma
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
@@ -314,45 +372,47 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k;
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<MRepeat / MXdlPack>{}, I1, Number<MXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<NRepeat / NXdlPack>{}, I1, Number<NXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_block_desc_m0_m1_m2_m3_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_block_desc_n0_n1_n2_n3_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp"
namespace ck {
@@ -29,7 +30,29 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t KPack>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmWmmaops_pipeline_v1<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmWmmaops_pipeline_v3<BlkGemmPipeSche,
BlockSize,

View File

@@ -61,7 +61,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
static constexpr auto wmma_gemm =
WmmaGemm<ADataType, BDataType, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
static constexpr index_t KRepeat = KPerBlock / KPack;
@@ -198,7 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");
static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
NPerBlock % (NPerWmma * NRepeat) == 0,
@@ -257,10 +257,10 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack * A_K1>{},
Number<A_K1>{},
Number<A_K1>{},
Number<1>{}));
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I1));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
@@ -271,10 +271,10 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack * B_K1>{},
Number<B_K1>{},
Number<B_K1>{},
Number<1>{}));
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I1));
// C[M, N, NumRegWmma]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -282,10 +282,10 @@ struct BlockwiseGemmWmmaops_pipeline_base
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ADataType,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -293,10 +293,10 @@ struct BlockwiseGemmWmmaops_pipeline_base
using BThreadCopy =
ThreadwiseTensorSliceTransfer_v4<BDataType,
BDataType,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,

View File

@@ -0,0 +1,638 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1
{
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::A_K1;
using Base::A_KRow;
using Base::B_K1;
using Base::B_KRow;
using Base::KRepeat;
using Base::WmmaK;
using Base::wmma_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
static TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
};
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
blockwise_gemm_func();
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
blockwise_gemm_func();
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeTypeA,
ComputeTypeB,
AccDataType,
AWmmaTileDesc,
BWmmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::A_K1;
using Base::A_KRow;
using Base::B_K1;
using Base::B_KRow;
using Base::KRepeat;
using Base::WmmaK;
using Base::wmma_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
static TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
I0,
I0,
I0,
I0,
I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0_inner, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
I0,
I0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0_inner, I0, I0, I0),
b_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
// but except the first, as we can shorten non-MAC cluster a bit and there's no
// observable negative impact. The desired effect is waves in a workgroup
// executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
// resource from other workgroups and reducing the chance of latency hiding by
// waiting for the rest of the workgroup at the eventual sync point.
if constexpr(k0_offset != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<ik / A_K1>{},
m0,
k0_inner,
I0,
I0,
Number<ik % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<ik / B_K1>{},
n0,
k0_inner,
I0,
I0,
Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
};
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
blockwise_gemm_func();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
blockwise_gemm_func();
}
}
protected:
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
Number<MRepeat>{},
Number<KRepeatPerCluster>{},
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I1));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
Number<NRepeat>{},
Number<KRepeatPerCluster>{},
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I1));
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
using BThreadCopy =
ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
using Base::c_thread_desc_;
};
} // namespace ck

View File

@@ -315,24 +315,18 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
b_thread_buf);
});
__builtin_amdgcn_sched_barrier(0);
@@ -363,12 +357,22 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(ik / A_K1, m0, k0, 0, 0, ik % A_K1))>{}];
make_tuple(Number<ik / A_K1>{},
m0,
k0,
I0,
I0,
Number<ik % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(ik / B_K1, n0, k0, 0, 0, ik % B_K1))>{}];
make_tuple(Number<ik / B_K1>{},
n0,
k0,
I0,
I0,
Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
@@ -377,7 +381,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
@@ -389,24 +393,20 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
b_thread_buf);
});
HotLoopScheduler();
@@ -426,13 +426,13 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(ik / A_K1, m0, k0, 0, 0, ik % A_K1))>{}];
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(ik / B_K1, n0, k0, 0, 0, ik % B_K1))>{}];
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
@@ -441,7 +441,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),

View File

@@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::MWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;

View File

@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -153,9 +154,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -290,12 +291,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// B VGPR->VGPR dequant
@@ -388,12 +391,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// B VGPR->VGPR dequant
@@ -477,12 +483,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// B VGPR->VGPR dequant
@@ -588,7 +596,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -154,9 +155,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -298,12 +299,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -382,12 +385,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -458,12 +464,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, k0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
@@ -556,7 +565,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,

View File

@@ -0,0 +1,952 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3
{
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::A_K1;
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::a_block_desc_m0_m1_m2_k;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::MWaves;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
__device__ static constexpr auto HotLoopScheduler()
{
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_total_stages = MRepeat;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_stages_more =
(num_buffer_load_inst_a + num_buffer_load_inst_b) -
math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
(num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_b_stages =
buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
? num_buffer_load_inst_b / buffer_load_perstage_more
: (buffer_load_stages_more +
(num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
buffer_load_perstage_less);
constexpr auto buffer_load_a_stages =
num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
ds_write_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
}
template <typename Stage>
__device__ static constexpr auto EpilogueScheduler_1(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b =
MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
}
__device__ static constexpr auto EpilogueScheduler_2()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
BBlockTransfer& b_blockwise_copy,
BBlockTransfer& b_blockwise_copy_up,
const BGridBuffer& b_grid_buf,
const BGridBuffer& b_grid_buf_up,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
CThreadBuffer& c_thread_buf_up,
index_t num_loop) const
{
ignore = b_block_buf;
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
// Global prefetch A1 B1
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
// // Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1
block_sync_lds();
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
c_thread_buf_up.Clear();
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value == MRepeat - 2)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
});
HotLoopScheduler();
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == MRepeat - 1)
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
});
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
});
HotLoopScheduler();
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec_up;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
});
}
}
protected:
// MRepeat MWave MLane KRepeat KLane KPack
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
// Reduce the vgpr usage here.
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(I2, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_;
};
} // namespace ck

View File

@@ -0,0 +1,919 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::KThreadChunk;
using Base::APackedSize;
using Base::BPackedSize;
using Base::ComputePackedSize;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleGridBuffer,
typename AScaleGridDesc,
typename AScaleThreadTransfer,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadTransfer>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
BBlockTransfer& b_blockwise_copy_up,
const BGridBuffer& b_grid_buf,
const BGridBuffer& b_grid_buf_up,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
CThreadBuffer& c_thread_buf_up,
// A and B scales
const AScaleGridDesc& a_scale_grid_desc,
AScaleThreadTransfer& a_scale_thread_copy,
const AScaleGridBuffer& a_scale_grid_buf,
const BScaleGridDesc& b_scale_grid_desc,
BScaleThreadTransfer& b_scale_thread_copy,
BScaleThreadTransfer& b_scale_thread_copy_up,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleGridBuffer& b_scale_grid_buf_up,
index_t num_loop) const
{
ignore = b_block_desc;
ignore = b_block_buf;
ignore = a_scale_grid_buf;
ignore = b_scale_grid_buf;
ignore = b_scale_grid_buf_up;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales to buf 0
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I0));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales to buf 0
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
auto b_scale_thread_buf_copy_up =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy_up);
b_scale_thread_bufs_up(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy_up[Number<0>{}];
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Prefetch a_scales to buf 1
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I1));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales to buf 1
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
auto b_scale_thread_buf_copy_up =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy_up);
b_scale_thread_bufs_up(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy_up[Number<0>{}];
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
c_thread_buf_up.Clear();
// main body
if constexpr(HasMainLoop)
{
// loop over k with the step KPerBlock
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(
0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[mfma_reg_buf]
[Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[mfma_reg_buf]
[Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[mfma_reg_buf]
[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk *
xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Prefetch a_scales
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(mfma_reg_buf));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
auto b_scale_thread_buf_copy_up =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy_up);
b_scale_thread_bufs_up(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy_up[Number<0>{}];
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[I1][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
protected:
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
// using Base::b_thread_desc_;
using Base::c_thread_desc_;
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};
} // namespace ck

View File

@@ -0,0 +1,155 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
// must be used for compute
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool GUFusion = false>
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
// Hardware MX GEMM pipeline
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
;
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
}
}
} // namespace ck

View File

@@ -0,0 +1,813 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::KThreadChunk;
using Base::APackedSize;
using Base::BPackedSize;
using Base::ComputePackedSize;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleGridBuffer,
typename AScaleGridDesc,
typename AScaleThreadTransfer,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadTransfer>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
// A and B scales
const AScaleGridDesc& a_scale_grid_desc,
AScaleThreadTransfer& a_scale_thread_copy,
const AScaleGridBuffer& a_scale_grid_buf,
const BScaleGridDesc& b_scale_grid_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
index_t num_loop) const
{
ignore = b_block_desc;
ignore = b_block_buf;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_buf(I0)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales to buf 0
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Prefetch a_scales to buf 1
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales to buf 1
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
// loop over k with the step KPerBlock
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(
0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[mfma_reg_buf]
[Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[mfma_reg_buf]
[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk *
xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Prefetch a_scales
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(mfma_reg_buf));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
protected:
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
// using Base::b_thread_desc_;
using Base::c_thread_desc_;
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};
} // namespace ck

View File

@@ -8,6 +8,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp"
@@ -171,26 +172,54 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
if constexpr(std::is_same<ADataType, BDataType>::value)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else
{

View File

@@ -270,10 +270,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// // Global prefetch A2
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

View File

@@ -58,11 +58,21 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t KGroup =
((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
static constexpr index_t KGroup = []() {
if constexpr(is_same_v<remove_cvref_t<ComputeDataType>, f8_t>)
// On gfx950, we have mfma that required 32 f8 elements as input,
// splited into 2 groups of 16 f8 elements.
// the 2 groups is not contiguous in the B preshuffed layout.
// and we do not want it to be contiguous in the B preshuffled layout
// because a memory instruction can only read 16 f8 elements at a time.
return ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
else
return 1;
}();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

View File

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MScaleBlock,
index_t NScaleBlock,
index_t KScaleBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmBlockScaleBPreshufflePipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
#if 0
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v2<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
#endif
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck

View File

@@ -0,0 +1,864 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MScaleBlock,
index_t NScaleBlock,
index_t KScaleBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1
{
};
template <index_t BlockSize,
typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MScaleBlock,
index_t NScaleBlock,
index_t KScaleBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack,
true>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack,
true>;
using Base::A_K1;
using Base::B_K1;
using Base::I0;
using Base::I1;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::a_block_desc_m0_m1_m2_k;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MWaves;
using Base::NWaves;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
__host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
__device__ static constexpr auto HotLoopScheduler()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
// A local
static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read
});
}
template <bool HasMainLoop,
int NumKBlockPerScale,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CScaleThreadDesc,
typename CThreadBuffer,
typename AScaleGridBuffer,
typename AScaleGridDesc,
typename AScaleThreadDesc,
typename AScaleThreadTransfer,
typename AScaleThreadTransferStep,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadDesc,
typename BScaleThreadTransfer,
typename BScaleThreadTransferStep>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
const CScaleThreadDesc& c_scale_thread_desc,
CThreadBuffer& c_thread_buf,
// AScaleThreadCopy
const AScaleGridDesc& a_scale_grid_desc,
const AScaleThreadDesc& a_scale_thread_desc,
AScaleThreadTransfer& a_scale_thread_copy,
const AScaleGridBuffer& a_scale_grid_buf,
const AScaleThreadTransferStep& a_scale_thread_copy_step,
// BScaleThreadCopy
const BScaleGridDesc& b_scale_grid_desc,
const BScaleThreadDesc& b_scale_thread_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleThreadTransferStep& b_scale_thread_copy_step,
// num_loop
index_t num_loop) const
{
ignore = b_block_desc;
ignore = b_block_buf;
// __builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize());
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize());
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
__builtin_amdgcn_sched_barrier(0);
constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
});
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
1,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
// Local prefetch A1
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
// __builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset =
CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block +
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<
b_thread_desc_.CalculateOffset(make_tuple(
n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}(
[&](auto t) {
using pk_fma_type =
typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) =
__builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec
.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf
.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step);
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_buf(Number<c_offset>{}) =
a_scale_thread_buf[Number<a_offset>{}] *
b_scale_thread_buf[Number<b_offset>{}];
});
});
});
block_sync_lds();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * KPack / KGroup>{}),
a_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
vector_type<AccDataType, 2> c_scale_thread_vec;
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat));
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[Number<cscale_offset>{}];
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
});
}
}
protected:
// MRepeat MWave MLane KRepeat KLane KPack
// KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
A_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_;
};
} // namespace ck

Some files were not shown because too many files have changed in this diff Show More