mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Merge remote-tracking branch 'origin/develop' into gfx950-mxfp4
This commit is contained in:
@@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
288
CMakeLists.txt
288
CMakeLists.txt
@@ -26,6 +26,10 @@ set(version 1.1.0)
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
|
||||
include(CTest)
|
||||
|
||||
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
|
||||
# CK Codegen requires dataclass which is added in Python 3.7
|
||||
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
|
||||
@@ -390,146 +394,152 @@ else()
|
||||
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
|
||||
endif()
|
||||
|
||||
## tidy
|
||||
include(EnableCompilerWarnings)
|
||||
## tidy
|
||||
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
|
||||
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
|
||||
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
|
||||
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
|
||||
# Enable tidy on hip
|
||||
elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
|
||||
set(CK_TIDY_ERRORS ALL)
|
||||
set(CK_TIDY_ERRORS ALL)
|
||||
endif()
|
||||
|
||||
|
||||
include(ClangTidy)
|
||||
enable_clang_tidy(
|
||||
CHECKS
|
||||
*
|
||||
-abseil-*
|
||||
-android-cloexec-fopen
|
||||
# Yea we shouldn't be using rand()
|
||||
-cert-msc30-c
|
||||
-bugprone-exception-escape
|
||||
-bugprone-macro-parentheses
|
||||
-cert-env33-c
|
||||
-cert-msc32-c
|
||||
-cert-msc50-cpp
|
||||
-cert-msc51-cpp
|
||||
-cert-dcl37-c
|
||||
-cert-dcl51-cpp
|
||||
-clang-analyzer-alpha.core.CastToStruct
|
||||
-clang-analyzer-optin.performance.Padding
|
||||
-clang-diagnostic-deprecated-declarations
|
||||
-clang-diagnostic-extern-c-compat
|
||||
-clang-diagnostic-unused-command-line-argument
|
||||
-cppcoreguidelines-avoid-c-arrays
|
||||
-cppcoreguidelines-avoid-magic-numbers
|
||||
-cppcoreguidelines-explicit-virtual-functions
|
||||
-cppcoreguidelines-init-variables
|
||||
-cppcoreguidelines-macro-usage
|
||||
-cppcoreguidelines-non-private-member-variables-in-classes
|
||||
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
|
||||
-cppcoreguidelines-pro-bounds-constant-array-index
|
||||
-cppcoreguidelines-pro-bounds-pointer-arithmetic
|
||||
-cppcoreguidelines-pro-type-member-init
|
||||
-cppcoreguidelines-pro-type-reinterpret-cast
|
||||
-cppcoreguidelines-pro-type-union-access
|
||||
-cppcoreguidelines-pro-type-vararg
|
||||
-cppcoreguidelines-special-member-functions
|
||||
-fuchsia-*
|
||||
-google-explicit-constructor
|
||||
-google-readability-braces-around-statements
|
||||
-google-readability-todo
|
||||
-google-runtime-int
|
||||
-google-runtime-references
|
||||
-hicpp-vararg
|
||||
-hicpp-braces-around-statements
|
||||
-hicpp-explicit-conversions
|
||||
-hicpp-named-parameter
|
||||
-hicpp-no-array-decay
|
||||
# We really shouldn't use bitwise operators with signed integers, but
|
||||
# opencl leaves us no choice
|
||||
-hicpp-avoid-c-arrays
|
||||
-hicpp-signed-bitwise
|
||||
-hicpp-special-member-functions
|
||||
-hicpp-uppercase-literal-suffix
|
||||
-hicpp-use-auto
|
||||
-hicpp-use-equals-default
|
||||
-hicpp-use-override
|
||||
-llvm-header-guard
|
||||
-llvm-include-order
|
||||
#-llvmlibc-*
|
||||
-llvmlibc-restrict-system-libc-headers
|
||||
-llvmlibc-callee-namespace
|
||||
-llvmlibc-implementation-in-namespace
|
||||
-llvm-else-after-return
|
||||
-llvm-qualified-auto
|
||||
-misc-misplaced-const
|
||||
-misc-non-private-member-variables-in-classes
|
||||
-misc-no-recursion
|
||||
-modernize-avoid-bind
|
||||
-modernize-avoid-c-arrays
|
||||
-modernize-pass-by-value
|
||||
-modernize-use-auto
|
||||
-modernize-use-default-member-init
|
||||
-modernize-use-equals-default
|
||||
-modernize-use-trailing-return-type
|
||||
-modernize-use-transparent-functors
|
||||
-performance-unnecessary-value-param
|
||||
-readability-braces-around-statements
|
||||
-readability-else-after-return
|
||||
# we are not ready to use it, but very useful
|
||||
-readability-function-cognitive-complexity
|
||||
-readability-isolate-declaration
|
||||
-readability-magic-numbers
|
||||
-readability-named-parameter
|
||||
-readability-uppercase-literal-suffix
|
||||
-readability-convert-member-functions-to-static
|
||||
-readability-qualified-auto
|
||||
-readability-redundant-string-init
|
||||
# too many narrowing conversions in our code
|
||||
-bugprone-narrowing-conversions
|
||||
-cppcoreguidelines-narrowing-conversions
|
||||
-altera-struct-pack-align
|
||||
-cppcoreguidelines-prefer-member-initializer
|
||||
${CK_TIDY_CHECKS}
|
||||
${CK_TIDY_ERRORS}
|
||||
HEADER_FILTER
|
||||
"\.hpp$"
|
||||
EXTRA_ARGS
|
||||
-DCK_USE_CLANG_TIDY
|
||||
)
|
||||
if(ENABLE_CLANG_CPP_CHECKS)
|
||||
include(ClangTidy)
|
||||
enable_clang_tidy(
|
||||
CHECKS
|
||||
*
|
||||
-abseil-*
|
||||
-android-cloexec-fopen
|
||||
# Yea we shouldn't be using rand()
|
||||
-cert-msc30-c
|
||||
-bugprone-exception-escape
|
||||
-bugprone-macro-parentheses
|
||||
-cert-env33-c
|
||||
-cert-msc32-c
|
||||
-cert-msc50-cpp
|
||||
-cert-msc51-cpp
|
||||
-cert-dcl37-c
|
||||
-cert-dcl51-cpp
|
||||
-clang-analyzer-alpha.core.CastToStruct
|
||||
-clang-analyzer-optin.performance.Padding
|
||||
-clang-diagnostic-deprecated-declarations
|
||||
-clang-diagnostic-extern-c-compat
|
||||
-clang-diagnostic-unused-command-line-argument
|
||||
-cppcoreguidelines-avoid-c-arrays
|
||||
-cppcoreguidelines-avoid-magic-numbers
|
||||
-cppcoreguidelines-explicit-virtual-functions
|
||||
-cppcoreguidelines-init-variables
|
||||
-cppcoreguidelines-macro-usage
|
||||
-cppcoreguidelines-non-private-member-variables-in-classes
|
||||
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
|
||||
-cppcoreguidelines-pro-bounds-constant-array-index
|
||||
-cppcoreguidelines-pro-bounds-pointer-arithmetic
|
||||
-cppcoreguidelines-pro-type-member-init
|
||||
-cppcoreguidelines-pro-type-reinterpret-cast
|
||||
-cppcoreguidelines-pro-type-union-access
|
||||
-cppcoreguidelines-pro-type-vararg
|
||||
-cppcoreguidelines-special-member-functions
|
||||
-fuchsia-*
|
||||
-google-explicit-constructor
|
||||
-google-readability-braces-around-statements
|
||||
-google-readability-todo
|
||||
-google-runtime-int
|
||||
-google-runtime-references
|
||||
-hicpp-vararg
|
||||
-hicpp-braces-around-statements
|
||||
-hicpp-explicit-conversions
|
||||
-hicpp-named-parameter
|
||||
-hicpp-no-array-decay
|
||||
# We really shouldn't use bitwise operators with signed integers, but
|
||||
# opencl leaves us no choice
|
||||
-hicpp-avoid-c-arrays
|
||||
-hicpp-signed-bitwise
|
||||
-hicpp-special-member-functions
|
||||
-hicpp-uppercase-literal-suffix
|
||||
-hicpp-use-auto
|
||||
-hicpp-use-equals-default
|
||||
-hicpp-use-override
|
||||
-llvm-header-guard
|
||||
-llvm-include-order
|
||||
#-llvmlibc-*
|
||||
-llvmlibc-restrict-system-libc-headers
|
||||
-llvmlibc-callee-namespace
|
||||
-llvmlibc-implementation-in-namespace
|
||||
-llvm-else-after-return
|
||||
-llvm-qualified-auto
|
||||
-misc-misplaced-const
|
||||
-misc-non-private-member-variables-in-classes
|
||||
-misc-no-recursion
|
||||
-modernize-avoid-bind
|
||||
-modernize-avoid-c-arrays
|
||||
-modernize-pass-by-value
|
||||
-modernize-use-auto
|
||||
-modernize-use-default-member-init
|
||||
-modernize-use-equals-default
|
||||
-modernize-use-trailing-return-type
|
||||
-modernize-use-transparent-functors
|
||||
-performance-unnecessary-value-param
|
||||
-readability-braces-around-statements
|
||||
-readability-else-after-return
|
||||
# we are not ready to use it, but very useful
|
||||
-readability-function-cognitive-complexity
|
||||
-readability-isolate-declaration
|
||||
-readability-magic-numbers
|
||||
-readability-named-parameter
|
||||
-readability-uppercase-literal-suffix
|
||||
-readability-convert-member-functions-to-static
|
||||
-readability-qualified-auto
|
||||
-readability-redundant-string-init
|
||||
# too many narrowing conversions in our code
|
||||
-bugprone-narrowing-conversions
|
||||
-cppcoreguidelines-narrowing-conversions
|
||||
-altera-struct-pack-align
|
||||
-cppcoreguidelines-prefer-member-initializer
|
||||
${CK_TIDY_CHECKS}
|
||||
${CK_TIDY_ERRORS}
|
||||
HEADER_FILTER
|
||||
"\.hpp$"
|
||||
EXTRA_ARGS
|
||||
-DCK_USE_CLANG_TIDY
|
||||
)
|
||||
|
||||
include(CppCheck)
|
||||
enable_cppcheck(
|
||||
CHECKS
|
||||
warning
|
||||
style
|
||||
performance
|
||||
portability
|
||||
SUPPRESS
|
||||
ConfigurationNotChecked
|
||||
constStatement
|
||||
duplicateCondition
|
||||
noExplicitConstructor
|
||||
passedByValue
|
||||
preprocessorErrorDirective
|
||||
shadowVariable
|
||||
unusedFunction
|
||||
unusedPrivateFunction
|
||||
unusedStructMember
|
||||
unmatchedSuppression
|
||||
FORCE
|
||||
SOURCES
|
||||
library/src
|
||||
INCLUDE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/library/include
|
||||
DEFINE
|
||||
CPPCHECK=1
|
||||
__linux__=1
|
||||
)
|
||||
include(CppCheck)
|
||||
enable_cppcheck(
|
||||
CHECKS
|
||||
warning
|
||||
style
|
||||
performance
|
||||
portability
|
||||
SUPPRESS
|
||||
ConfigurationNotChecked
|
||||
constStatement
|
||||
duplicateCondition
|
||||
noExplicitConstructor
|
||||
passedByValue
|
||||
preprocessorErrorDirective
|
||||
shadowVariable
|
||||
unusedFunction
|
||||
unusedPrivateFunction
|
||||
unusedStructMember
|
||||
unmatchedSuppression
|
||||
FORCE
|
||||
SOURCES
|
||||
library/src
|
||||
INCLUDE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/library/include
|
||||
DEFINE
|
||||
CPPCHECK=1
|
||||
__linux__=1
|
||||
)
|
||||
else()
|
||||
function(clang_tidy_check TARGET)
|
||||
# stub out empty function if clang tidy is not enabled
|
||||
endfunction()
|
||||
endif()
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
@@ -557,12 +567,15 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
|
||||
add_compile_options(-fdiagnostics-color=always)
|
||||
endif()
|
||||
|
||||
# make check runs the entire set of examples and tests
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
|
||||
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
|
||||
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
|
||||
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
|
||||
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
|
||||
if(NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
# make check runs the entire set of examples and tests
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
|
||||
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
|
||||
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
|
||||
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
|
||||
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
|
||||
endif()
|
||||
|
||||
|
||||
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
|
||||
@@ -607,6 +620,7 @@ ENDFOREACH()
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
add_subdirectory(library)
|
||||
|
||||
137
Jenkinsfile
vendored
137
Jenkinsfile
vendored
@@ -114,6 +114,9 @@ def check_arch(){
|
||||
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
|
||||
arch_type = 6
|
||||
}
|
||||
else if ( runShell('grep -n "gfx950" rocminfo.log') ) {
|
||||
arch_type = 7
|
||||
}
|
||||
return arch_type
|
||||
}
|
||||
|
||||
@@ -132,6 +135,10 @@ def getDockerImage(Map conf=[:]){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using special docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
@@ -208,6 +215,11 @@ def cmake_build(Map conf=[:]){
|
||||
|
||||
def build_type_debug = (conf.get("build_type",'release') == 'debug')
|
||||
|
||||
// use special compiler for gfx950
|
||||
if ( check_arch() == 7){
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
}
|
||||
|
||||
//cmake_env can overwrite default CXX variables.
|
||||
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
|
||||
|
||||
@@ -263,6 +275,9 @@ def cmake_build(Map conf=[:]){
|
||||
if (setup_args.contains("gfx94")){
|
||||
invocation_tag="gfx94"
|
||||
}
|
||||
if (setup_args.contains("gfx95")){
|
||||
invocation_tag="gfx95"
|
||||
}
|
||||
echo "invocation tag: ${invocation_tag}"
|
||||
def redis_pre_setup_cmd = pre_setup_cmd
|
||||
if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") {
|
||||
@@ -422,16 +437,6 @@ def buildHipClangJob(Map conf=[:]){
|
||||
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
@@ -455,7 +460,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
|
||||
def image
|
||||
def retimage
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
|
||||
@@ -496,17 +501,6 @@ def Build_CK(Map conf=[:]){
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
env.DOCKER_BUILDKIT=1
|
||||
checkout scm
|
||||
|
||||
def image
|
||||
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
|
||||
image = conf.get("docker_name", "")
|
||||
echo "Using legacy docker: ${image}"
|
||||
}
|
||||
else{
|
||||
image = getDockerImageName()
|
||||
echo "Using default docker: ${image}"
|
||||
}
|
||||
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
@@ -527,6 +521,7 @@ def Build_CK(Map conf=[:]){
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
def variant = env.STAGE_NAME
|
||||
def image
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
@@ -638,6 +633,13 @@ def Build_CK(Map conf=[:]){
|
||||
archiveArtifacts "perf_onnx_gemm_gfx908.log"
|
||||
stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908"
|
||||
}
|
||||
else if ( arch == 7 ){
|
||||
// run basic tests on gfx950
|
||||
echo "Run performance tests"
|
||||
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx950.log"
|
||||
stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950"
|
||||
}
|
||||
}
|
||||
}
|
||||
if (params.hipTensor_test && arch == 1 ){
|
||||
@@ -774,8 +776,8 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
@@ -848,8 +850,8 @@ pipeline {
|
||||
description: "Run the grouped conv large cases tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CODEGEN_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run codegen tests (default: OFF)")
|
||||
defaultValue: true,
|
||||
description: "Run codegen tests (default: ON)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_FMHA_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -862,6 +864,10 @@ pipeline {
|
||||
name: "RUN_CK_TILE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile GEMM tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_TILE_ENGINE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the tile_engine_gemm tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_INSTANCES_ONLY",
|
||||
defaultValue: false,
|
||||
@@ -870,6 +876,10 @@ pipeline {
|
||||
name: "BUILD_GFX908",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx908 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX950",
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx950 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX12",
|
||||
defaultValue: true,
|
||||
@@ -1145,6 +1155,48 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make benchmark_gemm -j && \
|
||||
./bin/benchmark_gemm """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
|
||||
make benchmark_gemm -j && \
|
||||
./bin/benchmark_gemm """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage("Build CK and run Tests")
|
||||
{
|
||||
@@ -1188,7 +1240,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK for all gfx9 targets")
|
||||
stage("Build CK and run Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
@@ -1203,6 +1255,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1210,6 +1263,29 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx950")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx950") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "rocm/composable_kernel-private:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on gfx908")
|
||||
{
|
||||
when {
|
||||
@@ -1223,6 +1299,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx908" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1243,6 +1320,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx90a" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1250,7 +1328,7 @@ pipeline {
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Build CK instances for different targets")
|
||||
stage("Build CK instances for all supported targets")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
@@ -1281,6 +1359,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx1030" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1301,6 +1380,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx1101" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
@@ -1321,6 +1401,7 @@ pipeline {
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx1201" \
|
||||
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
|
||||
@@ -48,6 +48,7 @@ rocm_install_targets(
|
||||
INCLUDE include
|
||||
)
|
||||
rocm_export_targets(
|
||||
TARGETS ck_host ck_headers
|
||||
EXPORT ck_host_targets
|
||||
NAMESPACE composable_kernel::
|
||||
)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core[api_reference]==1.18.4
|
||||
rocm-docs-core[api_reference]==1.20.0
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -237,7 +237,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core[api-reference]==1.18.4
|
||||
rocm-docs-core[api-reference]==1.20.0
|
||||
# via -r requirements.in
|
||||
rpds-py==0.24.0
|
||||
# via
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
@@ -117,8 +117,50 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cu) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cu = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
@@ -134,36 +176,50 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -220,17 +276,18 @@ class FmhaFwdApiTrait:
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -297,8 +354,8 @@ class FmhaFwdApiPool:
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
@@ -313,25 +370,27 @@ class FmhaFwdApiPool:
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
@@ -423,33 +482,21 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
@@ -458,53 +505,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
tiles = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
|
||||
@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
@@ -293,7 +300,7 @@ class FmhaFwdApiPool:
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
@@ -381,6 +388,7 @@ class FmhaFwdKernel:
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
@@ -419,7 +427,8 @@ class FmhaFwdKernel:
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@@ -453,36 +462,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -508,7 +517,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
@@ -532,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
@@ -540,6 +550,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
@@ -565,6 +576,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
|
||||
@@ -169,6 +169,7 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
ck_tile::index_t min_seqlen_q;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
bool kPadDv_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
|
||||
bool has_lse;
|
||||
bool has_dropout;
|
||||
bool do_fp8_static_quant;
|
||||
bool skip_min_seqlen_q = false;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -214,4 +214,15 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,4 +220,11 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -178,7 +178,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn)
|
||||
@@ -74,64 +75,102 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ave_time = ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
@@ -243,8 +282,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -345,7 +382,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
run_gemm_example(argc, argv);
|
||||
return !run_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -334,16 +334,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
|
||||
try
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string index_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
|
||||
bool r = true;
|
||||
if(weight_prec == "fp32" && index_prec == "int32")
|
||||
{
|
||||
r &= test_moe_sorting<float, ck_tile::index_t>(args);
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -320,4 +320,15 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
|
||||
#include "run_batched_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_batched_gemm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -319,4 +319,15 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
constexpr bool Persistent = false;
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_grouped_gemm_example<Persistent>(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -115,9 +116,47 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time{0};
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
if(args.k_batch == 1)
|
||||
@@ -132,8 +171,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
int run_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -177,4 +214,15 @@ int run_flatmm_example(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
return !run_flatmm_example(argc, argv);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,4 +133,11 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -122,7 +122,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
|
||||
float ave_time =
|
||||
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
{
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y * Z;
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM =
|
||||
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN =
|
||||
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
wei_grid_desc,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
|
||||
|
||||
@@ -351,6 +351,98 @@ struct Bilinear
|
||||
float beta_;
|
||||
};
|
||||
|
||||
struct AddClamp
|
||||
{
|
||||
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
const double a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
const half_t a = x0 + x1;
|
||||
y = a > type_convert<half_t>(floor_)
|
||||
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
|
||||
: type_convert<half_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + x1;
|
||||
y = a > type_convert<half_t>(floor_)
|
||||
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
|
||||
: type_convert<half_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(floor_)
|
||||
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
|
||||
: type_convert<bhalf_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
|
||||
{
|
||||
const float a = type_convert<float>(x0) + type_convert<float>(x1);
|
||||
y = a > type_convert<bhalf_t>(floor_)
|
||||
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
|
||||
: type_convert<bhalf_t>(floor_);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
|
||||
{
|
||||
const int8_t a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
const int8_t a = x0 + x1;
|
||||
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
|
||||
};
|
||||
|
||||
const float floor_;
|
||||
const float ceil_;
|
||||
};
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
|
||||
@@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmM = K * NumGroupsToMerge;
|
||||
const index_t GemmN = C * X * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmM = K * NumGroupsToMerge;
|
||||
const index_t GemmN = C * X * Y * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
@@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2
|
||||
const index_t GemmM = K * NumGroupsToMerge;
|
||||
const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
|
||||
|
||||
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
|
||||
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
|
||||
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
|
||||
|
||||
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
|
||||
#define CHAR_BIT 8
|
||||
using int8_t = signed char;
|
||||
using uint8_t = unsigned char;
|
||||
using int16_t = signed short;
|
||||
|
||||
@@ -55,8 +55,8 @@
|
||||
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_base.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
|
||||
@@ -1437,8 +1437,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
static_assert(
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, int32_t>::value &&
|
||||
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
@@ -1561,6 +1561,24 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
thread_buffer<float, 8> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
|
||||
{
|
||||
@@ -1597,6 +1615,24 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
thread_buffer<float, 8> tmp;
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<0>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset,
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
tmp.template get_as<fp32x4_t>()(number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
static_cast<index_t>(coherence));
|
||||
|
||||
return bit_cast<rtn_type>(tmp);
|
||||
}
|
||||
}
|
||||
else // other datatype
|
||||
{
|
||||
|
||||
@@ -35,4 +35,7 @@
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/host/timer.hpp"
|
||||
#include "ck_tile/host/flush_icache.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
|
||||
56
include/ck_tile/host/device_prop.hpp
Normal file
56
include/ck_tile/host/device_prop.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
|
||||
{
|
||||
return str.empty() ? h
|
||||
: fnv1a_hash(str.substr(1),
|
||||
(h ^ static_cast<unsigned char>(str.front())) * 16777619u);
|
||||
}
|
||||
inline std::string get_device_name()
|
||||
{
|
||||
hipDeviceProp_t props{};
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return std::string();
|
||||
}
|
||||
const std::string raw_name(props.gcnArchName);
|
||||
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
|
||||
switch(fnv1a_hash(name))
|
||||
{
|
||||
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
|
||||
case fnv1a_hash("Ellesmere"):
|
||||
case fnv1a_hash("Baffin"):
|
||||
case fnv1a_hash("RacerX"):
|
||||
case fnv1a_hash("Polaris10"):
|
||||
case fnv1a_hash("Polaris11"):
|
||||
case fnv1a_hash("Tonga"):
|
||||
case fnv1a_hash("Fiji"):
|
||||
case fnv1a_hash("gfx800"):
|
||||
case fnv1a_hash("gfx802"):
|
||||
case fnv1a_hash("gfx804"): return "gfx803";
|
||||
case fnv1a_hash("Vega10"):
|
||||
case fnv1a_hash("gfx901"): return "gfx900";
|
||||
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
|
||||
default: return name;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
#endif
|
||||
30
include/ck_tile/host/flush_icache.hpp
Normal file
30
include/ck_tile/host/flush_icache.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
static __global__ void flush_cache()
|
||||
{
|
||||
asm __volatile__("s_icache_inv \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t"
|
||||
"s_nop 0 \n\t" ::
|
||||
:);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,6 +11,13 @@
|
||||
#include <cstddef>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#define LOW_CU_PROCESSORS 80
|
||||
#define HIGH_CU_PROCESSORS 228
|
||||
#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS 0.005
|
||||
#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS 0.0015
|
||||
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
|
||||
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
@@ -81,6 +88,8 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
|
||||
template <typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
@@ -88,7 +97,7 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// warmup
|
||||
// Warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
@@ -114,4 +123,53 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
|
||||
}
|
||||
}
|
||||
|
||||
template <typename PreprocessFunc, typename... Callables>
|
||||
CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s,
|
||||
PreprocessFunc preprocess,
|
||||
Callables&&... callables)
|
||||
{
|
||||
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
|
||||
|
||||
if(!s.time_kernel_)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto time_launches = [&](auto timer) {
|
||||
// Warmup
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
|
||||
timer.start(s.stream_id_);
|
||||
for(int i = 0; i < s.nrepeat_; i++)
|
||||
{
|
||||
preprocess();
|
||||
launch_and_check(s, std::forward<Callables>(callables)...);
|
||||
}
|
||||
timer.stop(s.stream_id_);
|
||||
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
|
||||
float preprocess_offset = (deviceProps.multiProcessorCount >= HIGH_CU_PROCESSORS)
|
||||
? OPTIMAL_LATENCY_HIGH_CU_PROCESSORS
|
||||
: (deviceProps.multiProcessorCount == LOW_CU_PROCESSORS)
|
||||
? OPTIMAL_LATENCY_LOW_CU_PROCESSORS
|
||||
: OPTIMAL_LATENCY_SAFE_MARGIN;
|
||||
return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
|
||||
};
|
||||
|
||||
if(s.is_gpu_timer_)
|
||||
{
|
||||
return time_launches(gpu_timer{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return time_launches(cpu_timer{});
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
102
include/ck_tile/host/rotating_buffers.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
RotatingMemWrapper() = delete;
|
||||
RotatingMemWrapper(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
rotating_count(rotating_count_),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_)
|
||||
{
|
||||
p_a_grids.push_back(a_ptr);
|
||||
p_b_grids.push_back(b_ptr);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
}
|
||||
}
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
a_ptr = p_a_grids[idx];
|
||||
b_ptr = p_b_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapper() noexcept
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
a_ptr = p_a_grids[0];
|
||||
b_ptr = p_b_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
};
|
||||
inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -30,5 +30,7 @@ struct stream_config
|
||||
int cold_niters_ = 3;
|
||||
int nrepeat_ = 10;
|
||||
bool is_gpu_timer_ = true; // keep compatible
|
||||
bool flush_cache_ = false;
|
||||
int rotating_count_ = 1;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -53,6 +53,8 @@ struct FmhaFwdKernel
|
||||
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
|
||||
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
|
||||
|
||||
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdSkipMinSeqlenQKargs
|
||||
{
|
||||
ck_tile::index_t min_seqlen_q = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeKargs
|
||||
: FmhaFwdCommonKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
|
||||
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for logits_soft_cap
|
||||
{}, // placeholder for min_seqlen_q
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
kargs.min_seqlen_q = min_seqlen_q;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -969,7 +983,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -989,7 +1011,15 @@ struct FmhaFwdKernel
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1053,6 +1083,14 @@ struct FmhaFwdKernel
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
if constexpr(kSkipMinSeqlenQ)
|
||||
{
|
||||
if(kargs.seqlen_q <= kargs.min_seqlen_q)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
|
||||
@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
// assume that num_tile_n1 is always 1
|
||||
return ck_tile::make_tuple(
|
||||
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
@@ -53,6 +53,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
|
||||
@@ -19,7 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
|
||||
struct TileFmhaTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
@@ -33,6 +34,7 @@ struct TileFmhaTraits
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
|
||||
@@ -120,6 +120,7 @@ using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
|
||||
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_forward_bias_relu_xdl.inc"
|
||||
#include "grouped_convolution_forward_bias_clamp_xdl.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
@@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddRelu,
|
||||
ck::tensor_operation::element_wise::AddClamp,
|
||||
AComputeType,
|
||||
BComputeType>>
|
||||
{
|
||||
@@ -60,7 +60,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
OutDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddRelu,
|
||||
ck::tensor_operation::element_wise::AddClamp,
|
||||
AComputeType,
|
||||
BComputeType>;
|
||||
|
||||
@@ -80,23 +80,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -112,19 +112,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<AComputeType, ck::bhalf_t> &&
|
||||
is_same_v<BComputeType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -10,7 +10,7 @@ namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,9 +22,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -36,9 +36,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -50,9 +50,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -64,9 +64,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -78,9 +78,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -92,9 +92,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -106,9 +106,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -120,9 +120,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -134,9 +134,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -148,9 +148,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -162,9 +162,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -176,9 +176,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -190,9 +190,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -204,9 +204,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -218,9 +218,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -232,7 +232,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances);
|
||||
AddClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -104,17 +104,6 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(MIOPEN_REQ_LIBS_ONLY)
|
||||
message("Removing all sources that are not required for MIOpen")
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(source MATCHES "gemm" OR
|
||||
source MATCHES "mha" OR
|
||||
source MATCHES "contraction" OR
|
||||
source MATCHES "reduce")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
#message("remaining instances: ${ARGN}")
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
@@ -180,7 +169,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
target_compile_features(${INSTANCE_NAME} PUBLIC)
|
||||
|
||||
# flags to compress the library
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
if(NOT DISABLE_OFFLOAD_COMPRESS AND NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
|
||||
#message("Adding --offload-compress flag for ${INSTANCE_NAME}")
|
||||
target_compile_options(${INSTANCE_NAME} PRIVATE --offload-compress)
|
||||
endif()
|
||||
@@ -294,6 +283,17 @@ FOREACH(subdir_path ${dir_list})
|
||||
message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
|
||||
if(MIOPEN_REQ_LIBS_ONLY)
|
||||
message("Removing all sources that are not required for MIOpen")
|
||||
if("${cmake_instance}" MATCHES "gemm" OR
|
||||
"${cmake_instance}" MATCHES "mha" OR
|
||||
"${cmake_instance}" MATCHES "contraction" OR
|
||||
"${cmake_instance}" MATCHES "reduce")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if((add_inst EQUAL 1))
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
|
||||
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
|
||||
|
||||
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
|
||||
)
|
||||
@@ -10,7 +10,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -57,7 +57,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
@@ -42,7 +42,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
|
||||
@@ -52,7 +52,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -10,7 +10,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -57,7 +57,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
@@ -31,7 +31,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
@@ -41,7 +41,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
|
||||
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
@@ -31,7 +31,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
@@ -41,7 +41,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
|
||||
NHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<2,
|
||||
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
|
||||
NHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
@@ -43,7 +43,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
@@ -54,7 +54,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
@@ -43,7 +43,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
|
||||
@@ -54,7 +54,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -10,7 +10,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
|
||||
void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
{
|
||||
@@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -59,7 +59,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -70,7 +70,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
|
||||
NHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# ONLY XDL_KERNELS
|
||||
add_instance_library(device_grouped_conv2d_fwd_bias_relu_instance
|
||||
xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
|
||||
xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
|
||||
|
||||
xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp
|
||||
xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
|
||||
|
||||
xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
@@ -10,7 +10,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
@@ -41,7 +41,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
|
||||
NDHWGC,
|
||||
@@ -50,7 +50,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
|
||||
if(ck::get_device_name() != "gfx950")
|
||||
{
|
||||
@@ -63,7 +63,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
@@ -73,7 +73,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
|
||||
@@ -83,7 +83,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
if(ck::get_device_name() == "gfx950")
|
||||
@@ -97,7 +97,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
@@ -107,7 +107,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
|
||||
@@ -117,7 +117,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
@@ -31,7 +31,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
@@ -40,7 +40,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
|
||||
NDHWGC,
|
||||
@@ -49,7 +49,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
@@ -31,7 +31,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
@@ -40,7 +40,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
|
||||
NDHWGK,
|
||||
ConvFwd1x1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_instances<3,
|
||||
NDHWGC,
|
||||
@@ -49,7 +49,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
|
||||
NDHWGK,
|
||||
ConvFwd1x1S1P0,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
ConvFwdDefault,
|
||||
Interwave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
@@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
ConvFwd1x1P0,
|
||||
Interwave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
@@ -52,7 +52,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
ConvFwd1x1S1P0,
|
||||
Interwave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
ConvFwdDefault,
|
||||
Intrawave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
@@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
ConvFwd1x1P0,
|
||||
Intrawave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
|
||||
NDHWGC,
|
||||
@@ -52,7 +52,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
|
||||
ConvFwd1x1S1P0,
|
||||
Intrawave,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3,
|
||||
@@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
|
||||
NDHWGK,
|
||||
ConvFwd3x3,
|
||||
Tuple<BF16>,
|
||||
AddRelu>{});
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -1,16 +0,0 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
|
||||
|
||||
xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
|
||||
xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
|
||||
|
||||
xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_relu_instance ${GROUPED_CONV3D_FWD})
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
@@ -35,19 +35,22 @@ template <ck::index_t NDimSpatial,
|
||||
typename AComputeType = InDataType,
|
||||
typename BComputeType = AComputeType,
|
||||
typename IndexType = ck::index_t>
|
||||
bool profile_grouped_conv_fwd_bias_relu_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
const float floor = 0.f;
|
||||
const float ceil = 256.f;
|
||||
|
||||
const auto in_element_op = InElementOp{};
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
const auto out_element_op = OutElementOp{floor, ceil};
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
@@ -251,7 +251,7 @@ add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_fwd)
|
||||
add_subdirectory(convnd_bwd_data)
|
||||
add_subdirectory(grouped_convnd_fwd)
|
||||
add_subdirectory(grouped_convnd_fwd_bias_relu)
|
||||
add_subdirectory(grouped_convnd_fwd_bias_clamp)
|
||||
add_subdirectory(grouped_convnd_bwd_weight)
|
||||
add_subdirectory(block_to_ctile_map)
|
||||
add_subdirectory(softmax)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx94/gfx95
|
||||
# Currently ck_tile_gemm is only built on gfx94/gfx95
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS "")
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
@@ -12,8 +12,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
|
||||
-enable-noalias-to-md-conversion=0
|
||||
)
|
||||
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
|
||||
@@ -25,4 +23,3 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
else()
|
||||
message("Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
4
test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt
Normal file
4
test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
|
||||
endif()
|
||||
@@ -7,11 +7,11 @@
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp"
|
||||
#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndFwd : public ::testing::Test
|
||||
@@ -32,16 +32,16 @@ class TestGroupedConvndFwd : public ::testing::Test
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_relu_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType>(
|
||||
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType,
|
||||
IndexType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
@@ -1,4 +0,0 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_grouped_convnd_fwd_bias_relu test_grouped_convnd_fwd_bias_relu.cpp)
|
||||
target_link_libraries(test_grouped_convnd_fwd_bias_relu PRIVATE utility device_grouped_conv2d_fwd_bias_relu_instance device_grouped_conv3d_fwd_bias_relu_instance)
|
||||
endif()
|
||||
@@ -1,43 +1,60 @@
|
||||
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
|
||||
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
|
||||
|
||||
set(GEMM_CODEGEN_CPP_FILES "")
|
||||
set(GEMM_CODEGEN_HPP_FILES "")
|
||||
|
||||
foreach(blob ${GEMM_CODEGEN_BLOBS})
|
||||
string(STRIP "${blob}" stripped_blob)
|
||||
|
||||
if(stripped_blob MATCHES "\\.cpp$")
|
||||
list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}")
|
||||
elseif(stripped_blob MATCHES "\\.hpp$")
|
||||
list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt
|
||||
${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
|
||||
)
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
|
||||
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")
|
||||
add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES})
|
||||
# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja.
|
||||
set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX)
|
||||
target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(gemm_template_instances PRIVATE ${GEMM_CODEGEN_HPP_FILES})
|
||||
|
||||
set(BENCHMARK_GEMM_EXECUTABLE "benchmark_gemm")
|
||||
message("adding example ${BENCHMARK_GEMM_EXECUTABLE}")
|
||||
|
||||
# use build as include directory
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
|
||||
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})
|
||||
|
||||
add_library(gemm_host_api INTERFACE EXCLUDE_FROM_ALL)
|
||||
target_include_directories(gemm_host_api INTERFACE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(gemm_host_api INTERFACE ${GEMM_CODEGEN_HPP_FILES} gemm_host_api.hpp)
|
||||
target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
|
||||
|
||||
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
|
||||
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp)
|
||||
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api)
|
||||
|
||||
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
|
||||
|
||||
@@ -46,6 +63,6 @@ list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
|
||||
-Wno-float-equal
|
||||
--offload-compress)
|
||||
|
||||
target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
|
||||
target_compile_options(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
|
||||
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
@@ -4,10 +4,11 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes.
|
||||
User can provide kernel configuration such as tile size, warp size, padding, pipeline, scheduler and epilogue in the config file with limited values. For reference please see `./configs/user_provided_config.json`.
|
||||
|
||||
Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels.
|
||||
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
|
||||
|
||||
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
|
||||
|
||||
## Build Instructions
|
||||
``` bash
|
||||
@@ -16,41 +17,45 @@ mkdir build && cd build
|
||||
# build composable kernel
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with the appropriate architecture (example gfx942) or leave blank
|
||||
# generate the executable
|
||||
make tile_engine_gemm -j
|
||||
make benchmark_gemm -j
|
||||
```
|
||||
`tile_engine_gemm` will be located in the `./bin/` directory.
|
||||
`benchmark_gemm` will be located in the `./bin/` directory.
|
||||
|
||||
`benchmark_gemm` must be rebuilt everytime if configuration file is modified.
|
||||
|
||||
_`tile_engine_gemm` must be rebuilt everytime `instance_combination.json` is modified._
|
||||
``` bash
|
||||
rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild
|
||||
rm -rf tile_engine/ && make benchmark_gemm -j # rebuild
|
||||
```
|
||||
|
||||
## tile_engine_gemm inputs
|
||||
## benchmark_gemm inputs
|
||||
```
|
||||
-m The value for m dimension. Default is 3840.
|
||||
-n The value for n dimension. Default is 4096.
|
||||
-k The value for k dimension. Default is 2048.
|
||||
-stride_a The stride value for tensor A. Default is 0.
|
||||
-stride_b The stride value for tensor B. Default is 0.
|
||||
-stride_c The stride value for tensor C Default is 0.
|
||||
-split_k The split value for k dimension. Default is 1.
|
||||
-v The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 2, validation on GPU.
|
||||
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
|
||||
-warmup The number of iterations before benchmark the kernel. Default is 50.
|
||||
-repeat The number of iterations to benchmark the kernel. Default is 100.
|
||||
-timer Whether if the timer is gpu timer or not. Possible values are true or false. Default is true.
|
||||
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
|
||||
-flush_cache To flush cache in between different runs.Possible values are true or false. Default is false.
|
||||
-rotating_count count to flush cache. Default is 5.
|
||||
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
|
||||
-csv_filename The filename of benchmark result. Default is gemm_kernel.
|
||||
-structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false.
|
||||
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
|
||||
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
|
||||
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
|
||||
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
|
||||
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
|
||||
|
||||
-m m dimension (default:3840)
|
||||
-n n dimension (default:4096)
|
||||
-k k dimension (default:2048)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-split_k SplitK value (default:1)
|
||||
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
|
||||
-warmup Number of iterations before benchmark the kernel (default:50)
|
||||
-repeat Number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
|
||||
-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0)
|
||||
-pipeline possible values are: compv3, compv4, mem (default:compv3)
|
||||
-scheduler possible values are: intrawave, interwave (default:intrawave)
|
||||
-epilogue possible values are: cshuffle, default (default:cshuffle)
|
||||
-pad_m Pad in m direction - true/false (default:false)
|
||||
-pad_n Pad in n direction - true/false (default:false)
|
||||
-pad_k Pad in k direction - true/false (default:false)
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
|
||||
```
|
||||
Note: In `./configs/instance_combination.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
|
||||
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
|
||||
|
||||
## Example
|
||||
|
||||
@@ -86,7 +91,7 @@ The following JSON file specifies parameters used to generate and build GEMM ker
|
||||
|
||||
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
|
||||
``` bash
|
||||
./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
|
||||
./bin/benchmark_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
|
||||
```
|
||||
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.
|
||||
|
||||
|
||||
68
tile_engine/ops/gemm/benchmark_gemm.cpp
Normal file
68
tile_engine/ops/gemm/benchmark_gemm.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
|
||||
#include "gemm_profiler.hpp"
|
||||
#include "benchmark_gemm.hpp"
|
||||
|
||||
void benchmark_gemm(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
GemmProblem gemm_problem{arg_parser.get_int("split_k"),
|
||||
arg_parser.get_int("m"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("stride_a"),
|
||||
arg_parser.get_int("stride_b"),
|
||||
arg_parser.get_int("stride_c"),
|
||||
DataTypeTraits<ADataType>::name,
|
||||
DataTypeTraits<BDataType>::name,
|
||||
DataTypeTraits<AccDataType>::name,
|
||||
DataTypeTraits<CDataType>::name,
|
||||
ALayout::name,
|
||||
BLayout::name,
|
||||
CLayout::name,
|
||||
arg_parser.get_bool("structured_sparsity")};
|
||||
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count")};
|
||||
|
||||
auto& profiler = GemmProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
auto kernel_func = get_kernel_func_by_trait(arg_parser);
|
||||
profiler.benchmark(gemm_problem, kernel_func);
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
benchmark_gemm(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
235
tile_engine/ops/gemm/benchmark_gemm.hpp
Normal file
235
tile_engine/ops/gemm/benchmark_gemm.hpp
Normal file
@@ -0,0 +1,235 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "gemm_host_api.hpp"
|
||||
|
||||
enum class Metric
|
||||
{
|
||||
LATENCY = 0,
|
||||
TFLOPS = 1,
|
||||
BANDWIDTH = 2
|
||||
};
|
||||
|
||||
inline constexpr auto get_metric_name(Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return "latency";
|
||||
case Metric::TFLOPS: return "tflops";
|
||||
case Metric::BANDWIDTH: return "bandwidth";
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
struct GemmProblem
|
||||
{
|
||||
int split_k_;
|
||||
int m_, n_, k_;
|
||||
int stride_a_, stride_b_, stride_c_;
|
||||
|
||||
std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
|
||||
std::string layout_a_, layout_b_, layout_c_;
|
||||
|
||||
bool structured_sparsity_;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
||||
<< " \"m\":" << problem.m_ << ",\n"
|
||||
<< " \"n\":" << problem.n_ << ",\n"
|
||||
<< " \"k\":" << problem.k_ << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
||||
<< " \"stride_c\":" << problem.stride_c_ << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
||||
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
||||
<< " \"layout_c\":\"" << problem.layout_c_ << "\"\n"
|
||||
<< " \"structured_sparsity\":\"" << problem.structured_sparsity_ << "\"\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
GemmProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \""
|
||||
<< "{\n"
|
||||
<< obj.name_ << "\n}"
|
||||
<< "\",\n"
|
||||
<< " \"problem\": \"" << obj.problem_ << "\",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct Setting
|
||||
{
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
int verify_;
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
{
|
||||
std::ifstream version_file("/opt/rocm/.info/version");
|
||||
if(version_file.is_open())
|
||||
{
|
||||
std::string version;
|
||||
std::getline(version_file, version);
|
||||
return version;
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
bool compare(ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
void gemm_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
if(verify == 1)
|
||||
{
|
||||
c_m_n_host_result.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
|
||||
c_m_n_host_result.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
|
||||
}
|
||||
}
|
||||
239
tile_engine/ops/gemm/codegen_utils.py
Normal file
239
tile_engine/ops/gemm/codegen_utils.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {'fp32': 'float',
|
||||
'fp16': 'ck_tile::half_t',
|
||||
'bf16': 'ck_tile::bf16_t',
|
||||
'int8': 'ck_tile::int8_t',
|
||||
'fp8': 'ck_tile::fp8_t',
|
||||
'bf8': 'ck_tile::bf8_t',
|
||||
'int4': 'ck_tile::pk_int4_t'
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor',
|
||||
'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
true,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
HOT_LOOP_FALSE = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
|
||||
}
|
||||
"""
|
||||
RUN_MEM = """
|
||||
// Handle One and Full cases directly
|
||||
if (tail_num == ck_tile::TailNumber::One) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
} else if (tail_num == ck_tile::TailNumber::Full) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
([&]{
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > static_cast<int>(decltype(TNs)::value)) {
|
||||
if(tail_num == decltype(TNs)::value) {
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, decltype(TNs)::value>{});
|
||||
}
|
||||
}
|
||||
}(), ...);
|
||||
};
|
||||
|
||||
check_tail(
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
|
||||
);
|
||||
"""
|
||||
|
||||
RUN_COMPV3 = """
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
|
||||
}
|
||||
"""
|
||||
|
||||
RUN_COMPV4 = """
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
|
||||
'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
|
||||
'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
|
||||
|
||||
SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
|
||||
'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'}
|
||||
|
||||
EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE,
|
||||
'cshuffle': CSHUFFLE_EPILOGUE}
|
||||
|
||||
HOT_LOOP_TRUE = {'mem': RUN_MEM,
|
||||
'compv3': RUN_COMPV3,
|
||||
'compv4': RUN_COMPV4}
|
||||
|
||||
|
||||
def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
||||
|
||||
|
||||
# To Do: add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
||||
},
|
||||
"gfx942": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
||||
},
|
||||
"gfx950": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
}
|
||||
|
||||
# To Do: remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave")
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type in {'fp16', 'bf16'}:
|
||||
return 2
|
||||
elif data_type in {'int8', 'fp8', 'bf8'}:
|
||||
return 1
|
||||
elif data_type == 'int4':
|
||||
return 0.5
|
||||
else:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r'Name:\s*(gfx\d+\w*)')
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"],
|
||||
text=True,
|
||||
stderr=subprocess.PIPE,
|
||||
timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
print("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
print(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
130
tile_engine/ops/gemm/configs/default_config.json
Normal file
130
tile_engine/ops/gemm/configs/default_config.json
Normal file
@@ -0,0 +1,130 @@
|
||||
{
|
||||
"problem": {
|
||||
"layout_a": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": [
|
||||
"c"
|
||||
]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"datatype_a": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
},
|
||||
"datatype_b": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
},
|
||||
"datatype_c": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 512,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": []
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 512,
|
||||
"min": 64,
|
||||
"step": 32,
|
||||
"exclude": []
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 512,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": []
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv4",
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
{
|
||||
"architecture": {
|
||||
"values": ["gfx90a"]
|
||||
},
|
||||
"layout_a": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": ["c"]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": ["r"]
|
||||
},
|
||||
"datatype": {
|
||||
"values": ["fp16"]
|
||||
},
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [2]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [1]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [32]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [16]
|
||||
},
|
||||
"kPadM": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadN": {
|
||||
"values": [false]
|
||||
},
|
||||
"kPadK": {
|
||||
"values": [false]
|
||||
},
|
||||
"pipeline": {
|
||||
"values": ["compv3", "compv4", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["default", "cshuffle"]
|
||||
}
|
||||
}
|
||||
116
tile_engine/ops/gemm/configs/user_provided_config.json
Normal file
116
tile_engine/ops/gemm/configs/user_provided_config.json
Normal file
@@ -0,0 +1,116 @@
|
||||
{
|
||||
"problem": {
|
||||
"layout_a": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"layout_b": {
|
||||
"values": [
|
||||
"c"
|
||||
]
|
||||
},
|
||||
"layout_c": {
|
||||
"values": [
|
||||
"r"
|
||||
]
|
||||
},
|
||||
"datatype_a": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
},
|
||||
"datatype_b": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
},
|
||||
"datatype_c": {
|
||||
"values": [
|
||||
"fp16"
|
||||
]
|
||||
}
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
#include "gemm_dispatcher.hpp"
|
||||
#include "gemm_host_api.hpp"
|
||||
|
||||
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
int verify,
|
||||
bool structured_sparsity,
|
||||
KernelTraits& trait,
|
||||
ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& stream)
|
||||
{
|
||||
return GemmDispatcher::dispatch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
args,
|
||||
stream);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
const ALayout a_layout = ALayout{};
|
||||
const BLayout b_layout = BLayout{};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
int verify = arg_parser.get_int("v");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(structured_sparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args;
|
||||
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.k_batch = kbatch;
|
||||
gemm_args.M = M;
|
||||
gemm_args.N = N;
|
||||
gemm_args.K = K;
|
||||
gemm_args.stride_A = stride_A;
|
||||
gemm_args.stride_B = stride_B;
|
||||
gemm_args.stride_C = stride_C;
|
||||
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.kPadM = arg_parser.get_bool("pad_m");
|
||||
trait.kPadN = arg_parser.get_bool("pad_n");
|
||||
trait.kPadK = arg_parser.get_bool("pad_k");
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << std::endl;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(verify)
|
||||
{
|
||||
gemm_host_reference<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(verify,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
}
|
||||
|
||||
gemm_kernel_launch(c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
verify,
|
||||
structured_sparsity,
|
||||
trait,
|
||||
gemm_args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
218
tile_engine/ops/gemm/gemm_host_api.hpp
Executable file → Normal file
218
tile_engine/ops/gemm/gemm_host_api.hpp
Executable file → Normal file
@@ -1,16 +1,15 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_dispatcher.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
@@ -57,24 +56,6 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a
|
||||
/// specific kernel instance based on the provided settings.
|
||||
struct KernelTraits
|
||||
{
|
||||
/// @brief The name of the pipeline.
|
||||
std::string pipeline;
|
||||
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
|
||||
std::string scheduler;
|
||||
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
|
||||
std::string epilogue;
|
||||
/// @brief Indicates whether padding is applied to the M dimension.
|
||||
bool kPadM;
|
||||
/// @brief Indicates whether padding is applied to the N dimension.
|
||||
bool kPadN;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool kPadK;
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
@@ -82,49 +63,76 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("structured_sparsity", "0", "0:false, 1:true")
|
||||
.insert("pipeline", "compv3", "compv3, compv4, mem")
|
||||
.insert("scheduler", "intrawave", "intrawave, interwave")
|
||||
.insert("epilogue", "cshuffle", "cshuffle, default")
|
||||
.insert("pad_m", "false", "true, false")
|
||||
.insert("pad_n", "false", "true, false")
|
||||
.insert("pad_k", "false", "true, false");
|
||||
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
|
||||
.insert("n", "4096", "The value for n dimension. Default is 4096.")
|
||||
.insert("k", "2048", "The value for k dimension. Default is 2048.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_c", "0", "The stride value for tensor C Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"2",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU. Default is 2, validation on GPU.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Wether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert(
|
||||
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
|
||||
.insert(
|
||||
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"false",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is false.")
|
||||
.insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"gemm_kernel",
|
||||
"The filename of benchmark result. Default is gemm_kernel.")
|
||||
.insert("structured_sparsity",
|
||||
"false",
|
||||
"Whether use sparsity kernel or not. Possible values are true or false. Default is "
|
||||
"false")
|
||||
.insert(
|
||||
"pipeline",
|
||||
"compv3",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
|
||||
.insert("scheduler",
|
||||
"intrawave",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
|
||||
"compv3.")
|
||||
.insert(
|
||||
"epilogue",
|
||||
"cshuffle",
|
||||
"The type of epilogue. Possible values are cshuffle or default. Default is csshuffle.")
|
||||
.insert("pad_m",
|
||||
"false",
|
||||
"Whether pad or not in m direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_n",
|
||||
"false",
|
||||
"Whether pad or not in n direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_k",
|
||||
"false",
|
||||
"Whether pad or not in k direction. Possible values are true or false. Default is "
|
||||
"false.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -185,79 +193,17 @@ void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
void compare(ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void gemm_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
if(verify == 1)
|
||||
{
|
||||
c_m_n_host_result.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
|
||||
c_m_n_host_result.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
|
||||
}
|
||||
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
|
||||
|
||||
return GemmDispatcher::dispatch(structured_sparsity, trait);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
262
tile_engine/ops/gemm/gemm_profiler.hpp
Normal file
262
tile_engine/ops/gemm/gemm_profiler.hpp
Normal file
@@ -0,0 +1,262 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
public:
|
||||
static GemmProfiler& instance(Setting setting)
|
||||
{
|
||||
static GemmProfiler instance{setting};
|
||||
return instance;
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
gemm_problem.stride_a_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
|
||||
gemm_problem.stride_b_ = ck_tile::get_default_stride(
|
||||
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
|
||||
gemm_problem.stride_c_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
if(setting_.init_method_ == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else if(setting_.init_method_ == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(setting_.init_method_ == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(gemm_problem.structured_sparsity_)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args;
|
||||
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
gemm_args.k_batch = gemm_problem.split_k_;
|
||||
gemm_args.M = gemm_problem.m_;
|
||||
gemm_args.N = gemm_problem.n_;
|
||||
gemm_args.K = gemm_problem.k_;
|
||||
gemm_args.stride_A = gemm_problem.stride_a_;
|
||||
gemm_args.stride_B = gemm_problem.stride_b_;
|
||||
gemm_args.stride_C = gemm_problem.stride_c_;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
if(setting_.verify_)
|
||||
{
|
||||
gemm_host_reference(setting_.verify_,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_host_result,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
gemm_problem.stride_c_);
|
||||
}
|
||||
|
||||
for(auto& callable : callables)
|
||||
{
|
||||
auto kernel_run_result = callable(gemm_args,
|
||||
ck_tile::stream_config{nullptr,
|
||||
true,
|
||||
setting_.log_,
|
||||
setting_.n_warmup_,
|
||||
setting_.n_repeat_,
|
||||
setting_.is_gpu_timer_,
|
||||
setting_.flush_cache_,
|
||||
setting_.rotating_count_});
|
||||
process_result(gemm_problem,
|
||||
c_m_n_dev_buf,
|
||||
c_m_n_host_result,
|
||||
c_m_n_dev_result,
|
||||
kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmProblem& gemm_problem,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
// compute performance metric
|
||||
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
|
||||
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
|
||||
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
|
||||
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
|
||||
|
||||
// update
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
// verify result
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
// clear tensor
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
|
||||
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
||||
kernel_instances_.end(),
|
||||
[metric](const auto& a, const auto& b) {
|
||||
return PerformanceResult::compare(
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "The best kernel instance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file.tellp() == 0)
|
||||
{
|
||||
file << "rocm_version,device_name,"
|
||||
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
||||
<< "dtype_a,dtype_b,dtype_acc,dtype_c,"
|
||||
<< "layout_a,layout_b,layout_c,"
|
||||
<< "structured_sparsity,"
|
||||
<< "name,"
|
||||
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
||||
}
|
||||
|
||||
const auto& problem = kernel_instance.problem_;
|
||||
const auto& name = kernel_instance.name_;
|
||||
const auto& perf = kernel_instance.perf_result_;
|
||||
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
||||
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
||||
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
|
||||
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
|
||||
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
|
||||
<< "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GemmProfiler() { kernel_instances_.clear(); }
|
||||
GemmProfiler(Setting setting) : setting_(setting) {}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
202
tile_engine/ops/gemm/json_config.py
Normal file
202
tile_engine/ops/gemm/json_config.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Handles loading, parsing, and validation of JSON configuration parameters.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Tuple, Type, Dict
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
exclude: Optional[List[int]]
|
||||
|
||||
def generate_candidates(self) -> List[int]:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(
|
||||
f"Invalid range: min({self.min}) > max({self.max})"
|
||||
)
|
||||
if self.step <= 0:
|
||||
raise ValueError(
|
||||
f"Step must be positive, got {self.step}"
|
||||
)
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, 'exclude') and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
candidates = [x for x in candidates if x not in exclude_set]
|
||||
|
||||
if not candidates:
|
||||
raise ValueError(
|
||||
f"No valid candidates for range [{self.min}-{self.max}] "
|
||||
f"with step {self.step} and excludes {self.exclude}"
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProblemConfig:
|
||||
"""configuration class for problem parameter."""
|
||||
datatypes: Tuple[EnumConfigParam, ...]
|
||||
layouts: Tuple[EnumConfigParam, ...]
|
||||
|
||||
@property
|
||||
def datatype_map(self) -> Dict[str, str]:
|
||||
"""Get datatype as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.datatypes[0].values[0],
|
||||
'matrix_b': self.datatypes[1].values[0],
|
||||
'matrix_c': self.datatypes[2].values[0]
|
||||
}
|
||||
|
||||
@property
|
||||
def layout_map(self) -> Dict[str, str]:
|
||||
"""Get layout as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.layouts[0].values[0],
|
||||
'matrix_b': self.layouts[1].values[0],
|
||||
'matrix_c': self.layouts[2].values[0]
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
|
||||
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
"""Main configuration class for GEMM operations """
|
||||
problem: ProblemConfig
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
config_path = Path(filepath)
|
||||
|
||||
try:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open('r') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Parse problem config
|
||||
problem = ProblemConfig(
|
||||
datatypes=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_a']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_b']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_c']['values'])
|
||||
),
|
||||
layouts=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_a']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_b']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_c']['values'])
|
||||
)
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if 'values' in param_dict:
|
||||
return EnumConfigParam(values=param_dict['values'])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict['min'],
|
||||
max=param_dict['max'],
|
||||
step=param_dict['step'],
|
||||
exclude=param_dict.get('exclude', [])
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict['tile_config']['tile_m']),
|
||||
tile_n=create_param(config_dict['tile_config']['tile_n']),
|
||||
tile_k=create_param(config_dict['tile_config']['tile_k']),
|
||||
warp_m=create_param(config_dict['tile_config']['warp_m']),
|
||||
warp_n=create_param(config_dict['tile_config']['warp_n']),
|
||||
warp_k=create_param(config_dict['tile_config']['warp_k']),
|
||||
warp_tile_m=create_param(
|
||||
config_dict['tile_config']['warp_tile_m']),
|
||||
warp_tile_n=create_param(
|
||||
config_dict['tile_config']['warp_tile_n']),
|
||||
warp_tile_k=create_param(
|
||||
config_dict['tile_config']['warp_tile_k'])
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pipeline']['values']),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict['trait_config']['scheduler']['values']),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict['trait_config']['epilogue']['values']),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_m']['values']),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_n']['values']),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_k']['values'])
|
||||
)
|
||||
|
||||
return cls(
|
||||
problem=problem,
|
||||
tile_config=tile_config,
|
||||
trait_config=trait_config
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {str(e)}")
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Missing required configuration field: {str(e)}")
|
||||
Reference in New Issue
Block a user