Merge remote-tracking branch 'origin/develop' into gfx950-mxfp4

This commit is contained in:
Ding, Yi
2025-06-03 05:38:56 +00:00
85 changed files with 3218 additions and 1540 deletions

View File

@@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for fMHA forward kernels.
* Added benchmarking support for tile engine GEMM.
### Optimized

View File

@@ -26,6 +26,10 @@ set(version 1.1.0)
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
include(CTest)
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
# CK Codegen requires dataclass which is added in Python 3.7
# Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04
@@ -390,146 +394,152 @@ else()
add_compile_definitions(__HIP_PLATFORM_HCC__=1)
endif()
## tidy
include(EnableCompilerWarnings)
## tidy
set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
# Enable tidy on hip
elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU")
set(CK_TIDY_ERRORS ALL)
set(CK_TIDY_ERRORS ALL)
endif()
include(ClangTidy)
enable_clang_tidy(
CHECKS
*
-abseil-*
-android-cloexec-fopen
# Yea we shouldn't be using rand()
-cert-msc30-c
-bugprone-exception-escape
-bugprone-macro-parentheses
-cert-env33-c
-cert-msc32-c
-cert-msc50-cpp
-cert-msc51-cpp
-cert-dcl37-c
-cert-dcl51-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-avoid-c-arrays
-cppcoreguidelines-avoid-magic-numbers
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-init-variables
-cppcoreguidelines-macro-usage
-cppcoreguidelines-non-private-member-variables-in-classes
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
-cppcoreguidelines-pro-type-member-init
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-fuchsia-*
-google-explicit-constructor
-google-readability-braces-around-statements
-google-readability-todo
-google-runtime-int
-google-runtime-references
-hicpp-vararg
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-named-parameter
-hicpp-no-array-decay
# We really shouldn't use bitwise operators with signed integers, but
# opencl leaves us no choice
-hicpp-avoid-c-arrays
-hicpp-signed-bitwise
-hicpp-special-member-functions
-hicpp-uppercase-literal-suffix
-hicpp-use-auto
-hicpp-use-equals-default
-hicpp-use-override
-llvm-header-guard
-llvm-include-order
#-llvmlibc-*
-llvmlibc-restrict-system-libc-headers
-llvmlibc-callee-namespace
-llvmlibc-implementation-in-namespace
-llvm-else-after-return
-llvm-qualified-auto
-misc-misplaced-const
-misc-non-private-member-variables-in-classes
-misc-no-recursion
-modernize-avoid-bind
-modernize-avoid-c-arrays
-modernize-pass-by-value
-modernize-use-auto
-modernize-use-default-member-init
-modernize-use-equals-default
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
# we are not ready to use it, but very useful
-readability-function-cognitive-complexity
-readability-isolate-declaration
-readability-magic-numbers
-readability-named-parameter
-readability-uppercase-literal-suffix
-readability-convert-member-functions-to-static
-readability-qualified-auto
-readability-redundant-string-init
# too many narrowing conversions in our code
-bugprone-narrowing-conversions
-cppcoreguidelines-narrowing-conversions
-altera-struct-pack-align
-cppcoreguidelines-prefer-member-initializer
${CK_TIDY_CHECKS}
${CK_TIDY_ERRORS}
HEADER_FILTER
"\.hpp$"
EXTRA_ARGS
-DCK_USE_CLANG_TIDY
)
if(ENABLE_CLANG_CPP_CHECKS)
include(ClangTidy)
enable_clang_tidy(
CHECKS
*
-abseil-*
-android-cloexec-fopen
# Yea we shouldn't be using rand()
-cert-msc30-c
-bugprone-exception-escape
-bugprone-macro-parentheses
-cert-env33-c
-cert-msc32-c
-cert-msc50-cpp
-cert-msc51-cpp
-cert-dcl37-c
-cert-dcl51-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-avoid-c-arrays
-cppcoreguidelines-avoid-magic-numbers
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-init-variables
-cppcoreguidelines-macro-usage
-cppcoreguidelines-non-private-member-variables-in-classes
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
-cppcoreguidelines-pro-type-member-init
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-fuchsia-*
-google-explicit-constructor
-google-readability-braces-around-statements
-google-readability-todo
-google-runtime-int
-google-runtime-references
-hicpp-vararg
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-named-parameter
-hicpp-no-array-decay
# We really shouldn't use bitwise operators with signed integers, but
# opencl leaves us no choice
-hicpp-avoid-c-arrays
-hicpp-signed-bitwise
-hicpp-special-member-functions
-hicpp-uppercase-literal-suffix
-hicpp-use-auto
-hicpp-use-equals-default
-hicpp-use-override
-llvm-header-guard
-llvm-include-order
#-llvmlibc-*
-llvmlibc-restrict-system-libc-headers
-llvmlibc-callee-namespace
-llvmlibc-implementation-in-namespace
-llvm-else-after-return
-llvm-qualified-auto
-misc-misplaced-const
-misc-non-private-member-variables-in-classes
-misc-no-recursion
-modernize-avoid-bind
-modernize-avoid-c-arrays
-modernize-pass-by-value
-modernize-use-auto
-modernize-use-default-member-init
-modernize-use-equals-default
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
# we are not ready to use it, but very useful
-readability-function-cognitive-complexity
-readability-isolate-declaration
-readability-magic-numbers
-readability-named-parameter
-readability-uppercase-literal-suffix
-readability-convert-member-functions-to-static
-readability-qualified-auto
-readability-redundant-string-init
# too many narrowing conversions in our code
-bugprone-narrowing-conversions
-cppcoreguidelines-narrowing-conversions
-altera-struct-pack-align
-cppcoreguidelines-prefer-member-initializer
${CK_TIDY_CHECKS}
${CK_TIDY_ERRORS}
HEADER_FILTER
"\.hpp$"
EXTRA_ARGS
-DCK_USE_CLANG_TIDY
)
include(CppCheck)
enable_cppcheck(
CHECKS
warning
style
performance
portability
SUPPRESS
ConfigurationNotChecked
constStatement
duplicateCondition
noExplicitConstructor
passedByValue
preprocessorErrorDirective
shadowVariable
unusedFunction
unusedPrivateFunction
unusedStructMember
unmatchedSuppression
FORCE
SOURCES
library/src
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/library/include
DEFINE
CPPCHECK=1
__linux__=1
)
include(CppCheck)
enable_cppcheck(
CHECKS
warning
style
performance
portability
SUPPRESS
ConfigurationNotChecked
constStatement
duplicateCondition
noExplicitConstructor
passedByValue
preprocessorErrorDirective
shadowVariable
unusedFunction
unusedPrivateFunction
unusedStructMember
unmatchedSuppression
FORCE
SOURCES
library/src
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/library/include
DEFINE
CPPCHECK=1
__linux__=1
)
else()
function(clang_tidy_check TARGET)
# stub out empty function if clang tidy is not enabled
endfunction()
endif()
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
@@ -557,12 +567,15 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
add_compile_options(-fdiagnostics-color=always)
endif()
# make check runs the entire set of examples and tests
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
if(NOT MIOPEN_REQ_LIBS_ONLY)
# make check runs the entire set of examples and tests
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST")
# make regression runs the tests and examples that runs for more 30 seconds on gfx90a
add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST")
endif()
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
@@ -607,6 +620,7 @@ ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
add_subdirectory(library)

137
Jenkinsfile vendored
View File

@@ -114,6 +114,9 @@ def check_arch(){
else if ( runShell('grep -n "gfx908" rocminfo.log') ) {
arch_type = 6
}
else if ( runShell('grep -n "gfx950" rocminfo.log') ) {
arch_type = 7
}
return arch_type
}
@@ -132,6 +135,10 @@ def getDockerImage(Map conf=[:]){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using special docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
@@ -208,6 +215,11 @@ def cmake_build(Map conf=[:]){
def build_type_debug = (conf.get("build_type",'release') == 'debug')
// use special compiler for gfx950
if ( check_arch() == 7){
compiler = "/llvm-project/build/bin/clang++"
}
//cmake_env can overwrite default CXX variables.
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
@@ -263,6 +275,9 @@ def cmake_build(Map conf=[:]){
if (setup_args.contains("gfx94")){
invocation_tag="gfx94"
}
if (setup_args.contains("gfx95")){
invocation_tag="gfx95"
}
echo "invocation tag: ${invocation_tag}"
def redis_pre_setup_cmd = pre_setup_cmd
if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") {
@@ -422,16 +437,6 @@ def buildHipClangJob(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
checkout scm
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
@@ -455,7 +460,7 @@ def buildHipClangJob(Map conf=[:]){
echo "Docker flags: ${dockerOpts}"
def variant = env.STAGE_NAME
def image
def retimage
(retimage, image) = getDockerImage(conf)
@@ -496,17 +501,6 @@ def Build_CK(Map conf=[:]){
env.HSA_ENABLE_SDMA=0
env.DOCKER_BUILDKIT=1
checkout scm
def image
if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){
image = conf.get("docker_name", "")
echo "Using legacy docker: ${image}"
}
else{
image = getDockerImageName()
echo "Using default docker: ${image}"
}
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
@@ -527,6 +521,7 @@ def Build_CK(Map conf=[:]){
echo "Docker flags: ${dockerOpts}"
def variant = env.STAGE_NAME
def image
def retimage
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
@@ -638,6 +633,13 @@ def Build_CK(Map conf=[:]){
archiveArtifacts "perf_onnx_gemm_gfx908.log"
stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908"
}
else if ( arch == 7 ){
// run basic tests on gfx950
echo "Run performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950"
archiveArtifacts "perf_onnx_gemm_gfx950.log"
stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950"
}
}
}
if (params.hipTensor_test && arch == 1 ){
@@ -774,8 +776,8 @@ def process_results(Map conf=[:]){
}
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
@@ -848,8 +850,8 @@ pipeline {
description: "Run the grouped conv large cases tests (default: OFF)")
booleanParam(
name: "RUN_CODEGEN_TESTS",
defaultValue: false,
description: "Run codegen tests (default: OFF)")
defaultValue: true,
description: "Run codegen tests (default: ON)")
booleanParam(
name: "RUN_CK_TILE_FMHA_TESTS",
defaultValue: false,
@@ -862,6 +864,10 @@ pipeline {
name: "RUN_CK_TILE_GEMM_TESTS",
defaultValue: false,
description: "Run the ck_tile GEMM tests (default: OFF)")
booleanParam(
name: "RUN_TILE_ENGINE_GEMM_TESTS",
defaultValue: false,
description: "Run the tile_engine_gemm tests (default: OFF)")
booleanParam(
name: "BUILD_INSTANCES_ONLY",
defaultValue: false,
@@ -870,6 +876,10 @@ pipeline {
name: "BUILD_GFX908",
defaultValue: false,
description: "Build CK and run tests on gfx908 (default: OFF)")
booleanParam(
name: "BUILD_GFX950",
defaultValue: false,
description: "Build CK and run tests on gfx950 (default: OFF)")
booleanParam(
name: "BUILD_GFX12",
defaultValue: true,
@@ -1145,6 +1155,48 @@ pipeline {
}
}
}
stage("Run TILE_ENGINE_GEMM Tests")
{
parallel
{
stage("Run TILE_ENGINE_GEMM Tests on gfx90a")
{
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
make benchmark_gemm -j && \
./bin/benchmark_gemm """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Run TILE_ENGINE_GEMM Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942") }
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
make benchmark_gemm -j && \
./bin/benchmark_gemm """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
}
}
stage("Build CK and run Tests")
{
@@ -1188,7 +1240,7 @@ pipeline {
cleanWs()
}
}
stage("Build CK for all gfx9 targets")
stage("Build CK and run Tests on gfx942")
{
when {
beforeAgent true
@@ -1203,6 +1255,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1210,6 +1263,29 @@ pipeline {
cleanWs()
}
}
stage("Build CK and run Tests on gfx950")
{
when {
beforeAgent true
expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent{ label rocmnode("gfx950") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
-DGPU_TARGETS="gfx950" \
-DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx950" \
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "rocm/composable_kernel-private:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
stage("Build CK and run Tests on gfx908")
{
when {
@@ -1223,6 +1299,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx908" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1243,6 +1320,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx90a" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1250,7 +1328,7 @@ pipeline {
cleanWs()
}
}
stage("Build CK instances for different targets")
stage("Build CK instances for all supported targets")
{
when {
beforeAgent true
@@ -1281,6 +1359,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx1030" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1301,6 +1380,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx1101" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
@@ -1321,6 +1401,7 @@ pipeline {
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx1201" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{

View File

@@ -48,6 +48,7 @@ rocm_install_targets(
INCLUDE include
)
rocm_export_targets(
TARGETS ck_host ck_headers
EXPORT ck_host_targets
NAMESPACE composable_kernel::
)

View File

@@ -1,2 +1,2 @@
rocm-docs-core[api_reference]==1.18.4
rocm-docs-core[api_reference]==1.20.0
sphinxcontrib-bibtex==2.6.3

View File

@@ -237,7 +237,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core[api-reference]==1.18.4
rocm-docs-core[api-reference]==1.20.0
# via -r requirements.in
rpds-py==0.24.0
# via

View File

@@ -3,7 +3,7 @@
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
import fnmatch
import itertools
from pathlib import Path
@@ -117,8 +117,50 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
FMHA_FWD_API="""
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{
#include <cstdio>
namespace {{
bool get_num_cus(unsigned& num_cu) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cu = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
float r = -1;
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
{F_dispatch}
return r;
}}
@@ -134,36 +176,50 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return 'true'
else:
return f'{self.bool_expr}'
def __and__(self, other):
return CppConstraint(f'({str(self)}) && ({str(other)})')
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
constraint : CppConstraint
@property
def name(self) -> str:
@@ -220,17 +276,18 @@ class FmhaFwdApiTrait:
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
@@ -297,8 +354,8 @@ class FmhaFwdApiPool:
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
@@ -313,25 +370,27 @@ class FmhaFwdApiPool:
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
@@ -423,33 +482,21 @@ class FmhaFwdKernel:
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
class KernelComponentFactory:
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@staticmethod
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
@@ -458,53 +505,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
if bias == "bias":
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
tiles = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not

View File

@@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
{F_lse},
{F_dropout},
{F_squant},
{F_occupancy}>;
{F_occupancy},
{F_skip}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
@@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
#include <iostream>
@@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
@@ -160,11 +161,12 @@ class FmhaFwdApiTrait:
skpad : str
dpad : str
dvpad : str
skip : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
@property
def scheck(self) -> str:
@@ -227,6 +229,7 @@ class FmhaFwdPipeline:
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
@property
def name(self) -> str:
@@ -262,8 +265,12 @@ class FmhaFwdPipeline:
if self.F_dropout == 't' : n += '_dropout'
else: n += '_ndropout'
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
return n
class FmhaFwdApiPool:
@@ -293,7 +300,7 @@ class FmhaFwdApiPool:
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
@@ -381,6 +388,7 @@ class FmhaFwdKernel:
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -419,7 +427,8 @@ class FmhaFwdKernel:
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad)
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@@ -453,36 +462,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
# the below two is used for hdim vectorize load
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
@@ -508,7 +517,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
@@ -532,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# PyTorch integration
@@ -540,6 +550,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
@@ -565,6 +576,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -169,6 +169,7 @@ struct fmha_fwd_args
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
float p_drop;
bool s_randval;
@@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.window_size_left,
args.window_size_right,
args.mask_type,
args.min_seqlen_q,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
@@ -837,7 +839,8 @@ template <ck_tile::index_t HDim_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -861,6 +864,7 @@ struct fmha_fwd_traits_
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
@@ -995,6 +999,7 @@ struct fmha_fwd_traits
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
bool skip_min_seqlen_q = false;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

View File

@@ -214,4 +214,15 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -220,4 +220,11 @@ auto create_args(int argc, char* argv[])
}
// host API
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -178,7 +178,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =

View File

@@ -11,6 +11,7 @@
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
template <typename Pipeline, ck_tile::TailNumber TN>
void try_run(ck_tile::TailNumber tn)
@@ -74,64 +75,102 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto memory_operation = memory_operation_.value;
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
@@ -243,8 +282,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
return ave_time;
}
#include "run_gemm_example.inc"
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
@@ -345,7 +382,7 @@ int main(int argc, char* argv[])
{
try
{
run_gemm_example(argc, argv);
return !run_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{

View File

@@ -334,16 +334,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
try
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec == "fp32" && index_prec == "int32")
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
return r ? 0 : -1;
}

View File

@@ -320,4 +320,15 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_batched_gemm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -319,4 +319,15 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = false;
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_grouped_gemm_example<Persistent>(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -11,6 +11,7 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
template <typename ADataType,
typename BDataType,
@@ -115,9 +116,47 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
float ave_time{0};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
if(args.k_batch == 1)
@@ -132,8 +171,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
}
}
#include "run_flatmm_example.inc"
int run_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -177,4 +214,15 @@ int run_flatmm_example(int argc, char* argv[])
return -1;
}
int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
int main(int argc, char* argv[])
{
try
{
return !run_flatmm_example(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -133,4 +133,11 @@ auto create_args(int argc, char* argv[])
}
// host API
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -122,7 +122,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
float ave_time =
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,
@@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
{
const index_t GemmM = K;
const index_t GemmN = C * X * Y * Z;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM =
GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN =
GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
return transform_tensor_descriptor(
wei_grid_desc,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;

View File

@@ -351,6 +351,98 @@ struct Bilinear
float beta_;
};
struct AddClamp
{
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
{
const float a = x0 + x1;
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
};
template <>
__host__ __device__ constexpr void
operator()<double, double, double>(double& y, const double& x0, const double& x1) const
{
const double a = x0 + x1;
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
};
template <>
__host__ __device__ constexpr void
operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
const half_t a = x0 + x1;
y = a > type_convert<half_t>(floor_)
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
: type_convert<half_t>(floor_);
};
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
const float a = x0 + x1;
y = a > type_convert<half_t>(floor_)
? (a < type_convert<half_t>(ceil_) ? a : type_convert<half_t>(ceil_))
: type_convert<half_t>(floor_);
};
template <>
__host__ __device__ constexpr void
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float a = x0 + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(floor_)
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
: type_convert<bhalf_t>(floor_);
};
template <>
__host__ __device__ constexpr void
operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float a = type_convert<float>(x0) + type_convert<float>(x1);
y = a > type_convert<bhalf_t>(floor_)
? (a < type_convert<bhalf_t>(ceil_) ? a : type_convert<bhalf_t>(ceil_))
: type_convert<bhalf_t>(floor_);
};
template <>
__host__ __device__ constexpr void
operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
};
template <>
__host__ __device__ constexpr void
operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
const int8_t a = x0 + x1;
y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
};
const float floor_;
const float ceil_;
};
struct AddRelu
{
template <typename Y, typename X0, typename X1>

View File

@@ -1,6 +1,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * X;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm
const index_t GemmM = K;
const index_t GemmN = C * Z * X * Y;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =

View File

@@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
@@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2
const index_t GemmM = K * NumGroupsToMerge;
const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =

View File

@@ -11,6 +11,7 @@
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
#define CHAR_BIT 8
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;

View File

@@ -55,8 +55,8 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_base.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"

View File

@@ -1437,8 +1437,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
@@ -1561,6 +1561,24 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else if constexpr(std::is_same<T, bf16_t>::value) // bf16
{
@@ -1597,6 +1615,24 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
return bit_cast<rtn_type>(tmp);
}
else if constexpr(N == 16)
{
thread_buffer<float, 8> tmp;
tmp.template get_as<fp32x4_t>()(number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.template get_as<fp32x4_t>()(number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return bit_cast<rtn_type>(tmp);
}
}
else // other datatype
{

View File

@@ -35,4 +35,7 @@
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/flush_icache.hpp"
#include "ck_tile/host/rotating_buffers.hpp"

View File

@@ -0,0 +1,56 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
#include <string>
#include <string_view>
#include <hip/hip_runtime.h>
namespace ck_tile {
constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u)
{
return str.empty() ? h
: fnv1a_hash(str.substr(1),
(h ^ static_cast<unsigned char>(str.front())) * 16777619u);
}
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
switch(fnv1a_hash(name))
{
// https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
case fnv1a_hash("Ellesmere"):
case fnv1a_hash("Baffin"):
case fnv1a_hash("RacerX"):
case fnv1a_hash("Polaris10"):
case fnv1a_hash("Polaris11"):
case fnv1a_hash("Tonga"):
case fnv1a_hash("Fiji"):
case fnv1a_hash("gfx800"):
case fnv1a_hash("gfx802"):
case fnv1a_hash("gfx804"): return "gfx803";
case fnv1a_hash("Vega10"):
case fnv1a_hash("gfx901"): return "gfx900";
case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030";
default: return name;
}
}
} // namespace ck_tile
#endif

View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
namespace ck_tile {
static __global__ void flush_cache()
{
asm __volatile__("s_icache_inv \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t" ::
:);
}
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -11,6 +11,13 @@
#include <cstddef>
namespace ck_tile {
#define LOW_CU_PROCESSORS 80
#define HIGH_CU_PROCESSORS 228
#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS 0.005
#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS 0.0015
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
@@ -81,6 +88,8 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
if(!s.time_kernel_)
{
launch_and_check(s, std::forward<Callables>(callables)...);
@@ -88,7 +97,7 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
}
auto time_launches = [&](auto timer) {
// warmup
// Warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
@@ -114,4 +123,53 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
}
}
template <typename PreprocessFunc, typename... Callables>
CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s,
PreprocessFunc preprocess,
Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
if(!s.time_kernel_)
{
preprocess();
launch_and_check(s, std::forward<Callables>(callables)...);
return 0;
}
auto time_launches = [&](auto timer) {
// Warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
preprocess();
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.stop(s.stream_id_);
hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = (deviceProps.multiProcessorCount >= HIGH_CU_PROCESSORS)
? OPTIMAL_LATENCY_HIGH_CU_PROCESSORS
: (deviceProps.multiProcessorCount == LOW_CU_PROCESSORS)
? OPTIMAL_LATENCY_LOW_CU_PROCESSORS
: OPTIMAL_LATENCY_SAFE_MARGIN;
return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
};
if(s.is_gpu_timer_)
{
return time_launches(gpu_timer{});
}
else
{
return time_launches(cpu_timer{});
}
}
} // namespace ck_tile

View File

@@ -0,0 +1,102 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
template <typename ADataType, typename BDataType>
struct RotatingMemWrapper
{
RotatingMemWrapper() = delete;
RotatingMemWrapper(const void* a_ptr_,
const void* b_ptr_,
std::size_t rotating_count_,
std::size_t size_a_,
std::size_t size_b_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
rotating_count(rotating_count_),
size_a(size_a_),
size_b(size_b_)
{
p_a_grids.push_back(a_ptr);
p_b_grids.push_back(b_ptr);
for(size_t i = 1; i < rotating_count; i++)
{
{
void* pADeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]),
size_a_,
hipMemcpyDeviceToDevice));
p_a_grids.push_back(pADeviceBuf);
}
{
void* pBDeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]),
size_b_,
hipMemcpyDeviceToDevice));
p_b_grids.push_back(pBDeviceBuf);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
a_ptr = p_a_grids[idx];
b_ptr = p_b_grids[idx];
}
}
void Print()
{
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
<< ", rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapper() noexcept
{
if(rotating_count > 1)
{
// restore ptr
a_ptr = p_a_grids[0];
b_ptr = p_b_grids[0];
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
}
}
}
private:
const void* a_ptr;
const void* b_ptr;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::size_t size_a = 0;
std::size_t size_b = 0;
std::vector<const void*> p_a_grids;
std::vector<const void*> p_b_grids;
};
inline void flush_icache()
{
hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
HIP_CHECK_ERROR(hipGetLastError());
}
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -30,5 +30,7 @@ struct stream_config
int cold_niters_ = 3;
int nrepeat_ = 10;
bool is_gpu_timer_ = true; // keep compatible
bool flush_cache_ = false;
int rotating_count_ = 1;
};
} // namespace ck_tile

View File

@@ -53,6 +53,8 @@ struct FmhaFwdKernel
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -257,6 +259,11 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaFwdSkipMinSeqlenQKargs
{
ck_tile::index_t min_seqlen_q = 0;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
@@ -287,7 +294,8 @@ struct FmhaFwdKernel
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -664,6 +672,7 @@ struct FmhaFwdKernel
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
@@ -698,6 +707,7 @@ struct FmhaFwdKernel
{}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
{}, // placeholder for logits_soft_cap
{}, // placeholder for min_seqlen_q
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -753,6 +763,10 @@ struct FmhaFwdKernel
{
kargs.init_logits_soft_cap(logits_soft_cap);
}
if constexpr(kSkipMinSeqlenQ)
{
kargs.min_seqlen_q = min_seqlen_q;
}
return kargs;
}
@@ -969,7 +983,15 @@ struct FmhaFwdKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
else
{
@@ -989,7 +1011,15 @@ struct FmhaFwdKernel
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
}
}
@@ -1053,6 +1083,14 @@ struct FmhaFwdKernel
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if constexpr(kSkipMinSeqlenQ)
{
if(kargs.seqlen_q <= kargs.min_seqlen_q)
{
return;
}
}
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_q <= i_m0)

View File

@@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
if constexpr(kHasMask)
{
// assume that num_tile_n1 is always 1
return ck_tile::make_tuple(
(gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
else
{
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
}
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }

View File

@@ -53,6 +53,7 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;

View File

@@ -19,7 +19,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
struct TileFmhaTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -33,6 +34,7 @@ struct TileFmhaTraits
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,

View File

@@ -120,6 +120,7 @@ using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
using AddSilu = ck::tensor_operation::element_wise::AddSilu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu;

View File

@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>;
using namespace ck::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
static constexpr auto ConvFwdDefault =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

View File

@@ -13,7 +13,7 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_USE_XDL
#include "grouped_convolution_forward_bias_relu_xdl.inc"
#include "grouped_convolution_forward_bias_clamp_xdl.inc"
#endif
namespace ck {
@@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddRelu,
ck::tensor_operation::element_wise::AddClamp,
AComputeType,
BComputeType>>
{
@@ -60,7 +60,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
OutDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddRelu,
ck::tensor_operation::element_wise::AddClamp,
AComputeType,
BComputeType>;
@@ -80,23 +80,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif
@@ -112,19 +112,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<AComputeType, ck::bhalf_t> &&
is_same_v<BComputeType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
op_ptrs);
add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
op_ptrs);
}
#endif

View File

@@ -10,7 +10,7 @@ namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,9 +22,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -36,9 +36,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -50,9 +50,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -64,9 +64,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -78,9 +78,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -92,9 +92,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -106,9 +106,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -120,9 +120,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -134,9 +134,9 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -148,9 +148,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -162,9 +162,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -176,9 +176,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -190,9 +190,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -204,9 +204,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -218,9 +218,9 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -232,7 +232,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances);
AddClamp>>>& instances);
#endif

View File

@@ -104,17 +104,6 @@ function(add_instance_library INSTANCE_NAME)
endif()
endforeach()
if(MIOPEN_REQ_LIBS_ONLY)
message("Removing all sources that are not required for MIOpen")
foreach(source IN LISTS ARGN)
if(source MATCHES "gemm" OR
source MATCHES "mha" OR
source MATCHES "contraction" OR
source MATCHES "reduce")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
#message("remaining instances: ${ARGN}")
#only continue if there are some source files left on the list
if(ARGN)
@@ -180,7 +169,7 @@ function(add_instance_library INSTANCE_NAME)
target_compile_features(${INSTANCE_NAME} PUBLIC)
# flags to compress the library
if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
if(NOT DISABLE_OFFLOAD_COMPRESS AND NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132)
#message("Adding --offload-compress flag for ${INSTANCE_NAME}")
target_compile_options(${INSTANCE_NAME} PRIVATE --offload-compress)
endif()
@@ -294,6 +283,17 @@ FOREACH(subdir_path ${dir_list})
message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
set(add_inst 0)
endif()
if(MIOPEN_REQ_LIBS_ONLY)
message("Removing all sources that are not required for MIOpen")
if("${cmake_instance}" MATCHES "gemm" OR
"${cmake_instance}" MATCHES "mha" OR
"${cmake_instance}" MATCHES "contraction" OR
"${cmake_instance}" MATCHES "reduce")
set(add_inst 0)
endif()
endif()
if((add_inst EQUAL 1))
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})

View File

@@ -0,0 +1,16 @@
# ONLY XDL_KERNELS
add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
)

View File

@@ -10,7 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
if(ck::get_device_name() == "gfx950")
{
@@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
@@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
@@ -57,7 +57,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
}

View File

@@ -10,7 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
@@ -42,7 +42,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<2,
@@ -52,7 +52,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -10,7 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
if(ck::get_device_name() != "gfx950")
{
@@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
@@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
@@ -57,7 +57,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
}

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
@@ -31,7 +31,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
@@ -41,7 +41,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_16x16_instances<2,
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<2,
@@ -31,7 +31,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<2,
@@ -41,7 +41,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
NHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<2,
@@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance
NHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
ConvFwdDefault,
Interwave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
@@ -43,7 +43,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
ConvFwd1x1P0,
Interwave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
@@ -54,7 +54,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte
ConvFwd1x1S1P0,
Interwave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
@@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
ConvFwdDefault,
Intrawave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
@@ -43,7 +43,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
ConvFwd1x1P0,
Intrawave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<2,
@@ -54,7 +54,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr
ConvFwd1x1S1P0,
Intrawave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -10,7 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
NHWGC,
GKYXC,
@@ -22,7 +22,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
if(ck::get_device_name() == "gfx950")
{
@@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
@@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
NHWGK,
ConvFwd3x3,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
else
{
@@ -59,7 +59,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
NHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
@@ -70,7 +70,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk
NHWGK,
ConvFwd3x3,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
}

View File

@@ -1,16 +0,0 @@
# ONLY XDL_KERNELS
add_instance_library(device_grouped_conv2d_fwd_bias_relu_instance
xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp
)

View File

@@ -0,0 +1,16 @@
# ONLY XDL_KERNELS
set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp
)
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})

View File

@@ -10,7 +10,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -22,7 +22,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
NDHWGC,
@@ -41,7 +41,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances<3,
NDHWGC,
@@ -50,7 +50,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
if(ck::get_device_name() != "gfx950")
{
@@ -63,7 +63,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
@@ -73,7 +73,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3,
@@ -83,7 +83,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
if(ck::get_device_name() == "gfx950")
@@ -97,7 +97,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
@@ -107,7 +107,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3,
@@ -117,7 +117,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
}

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
@@ -31,7 +31,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
NDHWGC,
@@ -40,7 +40,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_16x16_instances<3,
NDHWGC,
@@ -49,7 +49,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<3,
@@ -31,7 +31,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<3,
NDHWGC,
@@ -40,7 +40,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
NDHWGK,
ConvFwd1x1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_instances<3,
NDHWGC,
@@ -49,7 +49,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta
NDHWGK,
ConvFwd1x1S1P0,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
ConvFwdDefault,
Interwave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
NDHWGC,
@@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
ConvFwd1x1P0,
Interwave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
NDHWGC,
@@ -52,7 +52,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
ConvFwd1x1S1P0,
Interwave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
ConvFwdDefault,
Intrawave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
NDHWGC,
@@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
ConvFwd1x1P0,
Intrawave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(instances,
device_grouped_conv_fwd_xdl_bf16_mem_instances<3,
NDHWGC,
@@ -52,7 +52,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i
ConvFwd1x1S1P0,
Intrawave,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
GKZYXC,
@@ -21,7 +21,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
AddClamp>>>& instances)
{
add_device_operation_instances(
instances,
@@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
NDHWGK,
ConvFwdDefault,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3,
@@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh
NDHWGK,
ConvFwd3x3,
Tuple<BF16>,
AddRelu>{});
AddClamp>{});
}
} // namespace instance

View File

@@ -1,16 +0,0 @@
# ONLY XDL_KERNELS
set(GROUPED_CONV3D_FWD
xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp
xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp
xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp
)
add_instance_library(device_grouped_conv3d_fwd_bias_relu_instance ${GROUPED_CONV3D_FWD})

View File

@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -35,19 +35,22 @@ template <ck::index_t NDimSpatial,
typename AComputeType = InDataType,
typename BComputeType = AComputeType,
typename IndexType = ck::index_t>
bool profile_grouped_conv_fwd_bias_relu_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddRelu;
using OutElementOp = ck::tensor_operation::element_wise::AddClamp;
const float floor = 0.f;
const float ceil = 256.f;
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
const auto out_element_op = OutElementOp{floor, ceil};
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);

View File

@@ -251,7 +251,7 @@ add_subdirectory(reduce)
add_subdirectory(convnd_fwd)
add_subdirectory(convnd_bwd_data)
add_subdirectory(grouped_convnd_fwd)
add_subdirectory(grouped_convnd_fwd_bias_relu)
add_subdirectory(grouped_convnd_fwd_bias_clamp)
add_subdirectory(grouped_convnd_bwd_weight)
add_subdirectory(block_to_ctile_map)
add_subdirectory(softmax)

View File

@@ -1,4 +1,4 @@
# Currently ck_tile is only built on gfx94/gfx95
# Currently ck_tile_gemm is only built on gfx94/gfx95
set(EXAMPLE_GEMM_COMPILE_OPTIONS "")
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
@@ -12,8 +12,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS
-enable-noalias-to-md-conversion=0
)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp)
add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp)
@@ -25,4 +23,3 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
else()
message("Skipping ck_tile_gemm tests for current target")
endif()
endif()

View File

@@ -0,0 +1,4 @@
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance)
endif()

View File

@@ -7,11 +7,11 @@
#include <vector>
#include <gtest/gtest.h>
#include "profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp"
#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
template <typename Tuple>
class TestGroupedConvndFwd : public ::testing::Test
@@ -32,16 +32,16 @@ class TestGroupedConvndFwd : public ::testing::Test
bool pass = true;
for(auto& param : conv_params)
{
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_relu_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType>(
pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DataType,
DataType,
DataType,
DataType,
DataType,
IndexType>(
true, // do_verification
1, // init_method: integer value
false, // do_log

View File

@@ -1,4 +0,0 @@
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_grouped_convnd_fwd_bias_relu test_grouped_convnd_fwd_bias_relu.cpp)
target_link_libraries(test_grouped_convnd_fwd_bias_relu PRIVATE utility device_grouped_conv2d_fwd_bias_relu_instance device_grouped_conv3d_fwd_bias_relu_instance)
endif()

View File

@@ -1,43 +1,60 @@
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
message( FATAL_ERROR "Fail to list kernels via Python. ${ret}")
endif()
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS)
set(GEMM_CODEGEN_CPP_FILES "")
set(GEMM_CODEGEN_HPP_FILES "")
foreach(blob ${GEMM_CODEGEN_BLOBS})
string(STRIP "${blob}" stripped_blob)
if(stripped_blob MATCHES "\\.cpp$")
list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}")
elseif(stripped_blob MATCHES "\\.hpp$")
list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}")
endif()
endforeach()
add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt
${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json
)
set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm")
message("adding example ${EXECUTABLE_GEMM_INSTANCE}")
add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES})
# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja.
set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX)
target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(gemm_template_instances PRIVATE ${GEMM_CODEGEN_HPP_FILES})
set(BENCHMARK_GEMM_EXECUTABLE "benchmark_gemm")
message("adding example ${BENCHMARK_GEMM_EXECUTABLE}")
# use build as include directory
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp)
target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS})
add_library(gemm_host_api INTERFACE EXCLUDE_FROM_ALL)
target_include_directories(gemm_host_api INTERFACE ${CMAKE_CURRENT_LIST_DIR})
target_sources(gemm_host_api INTERFACE ${GEMM_CODEGEN_HPP_FILES} gemm_host_api.hpp)
target_link_libraries(gemm_host_api INTERFACE gemm_template_instances)
add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp)
target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp)
target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api)
set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS)
@@ -46,6 +63,6 @@ list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS
-Wno-float-equal
--offload-compress)
target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
target_compile_options(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS})
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -4,10 +4,11 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb
# Kernel Configurations
Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes.
User can provide kernel configuration such as tile size, warp size, padding, pipeline, scheduler and epilogue in the config file with limited values. For reference please see `./configs/user_provided_config.json`.
Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels.
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
## Build Instructions
``` bash
@@ -16,41 +17,45 @@ mkdir build && cd build
# build composable kernel
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with the appropriate architecture (example gfx942) or leave blank
# generate the executable
make tile_engine_gemm -j
make benchmark_gemm -j
```
`tile_engine_gemm` will be located in the `./bin/` directory.
`benchmark_gemm` will be located in the `./bin/` directory.
`benchmark_gemm` must be rebuilt everytime if configuration file is modified.
_`tile_engine_gemm` must be rebuilt everytime `instance_combination.json` is modified._
``` bash
rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild
rm -rf tile_engine/ && make benchmark_gemm -j # rebuild
```
## tile_engine_gemm inputs
## benchmark_gemm inputs
```
-m The value for m dimension. Default is 3840.
-n The value for n dimension. Default is 4096.
-k The value for k dimension. Default is 2048.
-stride_a The stride value for tensor A. Default is 0.
-stride_b The stride value for tensor B. Default is 0.
-stride_c The stride value for tensor C Default is 0.
-split_k The split value for k dimension. Default is 1.
-v The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 2, validation on GPU.
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
-warmup The number of iterations before benchmark the kernel. Default is 50.
-repeat The number of iterations to benchmark the kernel. Default is 100.
-timer Whether if the timer is gpu timer or not. Possible values are true or false. Default is true.
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
-flush_cache To flush cache in between different runs.Possible values are true or false. Default is false.
-rotating_count count to flush cache. Default is 5.
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
-csv_filename The filename of benchmark result. Default is gemm_kernel.
-structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false.
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
-m m dimension (default:3840)
-n n dimension (default:4096)
-k k dimension (default:2048)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-split_k SplitK value (default:1)
-v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2)
-warmup Number of iterations before benchmark the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0)
-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0)
-pipeline possible values are: compv3, compv4, mem (default:compv3)
-scheduler possible values are: intrawave, interwave (default:intrawave)
-epilogue possible values are: cshuffle, default (default:cshuffle)
-pad_m Pad in m direction - true/false (default:false)
-pad_n Pad in n direction - true/false (default:false)
-pad_k Pad in k direction - true/false (default:false)
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
```
Note: In `./configs/instance_combination.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
## Example
@@ -86,7 +91,7 @@ The following JSON file specifies parameters used to generate and build GEMM ker
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
``` bash
./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
./bin/benchmark_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default
```
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.

View File

@@ -0,0 +1,68 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <tuple>
#include <exception>
#include "gemm_profiler.hpp"
#include "benchmark_gemm.hpp"
void benchmark_gemm(const ck_tile::ArgParser& arg_parser)
{
GemmProblem gemm_problem{arg_parser.get_int("split_k"),
arg_parser.get_int("m"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("stride_a"),
arg_parser.get_int("stride_b"),
arg_parser.get_int("stride_c"),
DataTypeTraits<ADataType>::name,
DataTypeTraits<BDataType>::name,
DataTypeTraits<AccDataType>::name,
DataTypeTraits<CDataType>::name,
ALayout::name,
BLayout::name,
CLayout::name,
arg_parser.get_bool("structured_sparsity")};
Setting setting{arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),
arg_parser.get_int("verify"),
arg_parser.get_int("init"),
arg_parser.get_bool("log"),
arg_parser.get_str("csv_filename"),
arg_parser.get_bool("flush_cache"),
arg_parser.get_int("rotating_count")};
auto& profiler = GemmProfiler::instance(setting);
try
{
auto kernel_func = get_kernel_func_by_trait(arg_parser);
profiler.benchmark(gemm_problem, kernel_func);
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
}
catch(const std::exception& e)
{
std::cerr << "Benchmark failed: " << e.what() << std::endl;
}
}
int main(int argc, char* argv[])
{
try
{
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
benchmark_gemm(parser);
return 0;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,235 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include "gemm_host_api.hpp"
enum class Metric
{
LATENCY = 0,
TFLOPS = 1,
BANDWIDTH = 2
};
inline constexpr auto get_metric_name(Metric m)
{
switch(m)
{
case Metric::LATENCY: return "latency";
case Metric::TFLOPS: return "tflops";
case Metric::BANDWIDTH: return "bandwidth";
default: throw std::invalid_argument("Unsupported metric type");
}
}
struct GemmProblem
{
int split_k_;
int m_, n_, k_;
int stride_a_, stride_b_, stride_c_;
std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
std::string layout_a_, layout_b_, layout_c_;
bool structured_sparsity_;
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
{
os << "{\n"
<< " \"split_k\":" << problem.split_k_ << ",\n"
<< " \"m\":" << problem.m_ << ",\n"
<< " \"n\":" << problem.n_ << ",\n"
<< " \"k\":" << problem.k_ << ",\n"
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
<< " \"stride_c\":" << problem.stride_c_ << ",\n"
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
<< " \"layout_c\":\"" << problem.layout_c_ << "\"\n"
<< " \"structured_sparsity\":\"" << problem.structured_sparsity_ << "\"\n"
<< "}";
return os;
}
};
struct PerformanceResult
{
double latency_;
double tflops_;
double bandwidth_;
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
{
switch(m)
{
case Metric::LATENCY: return a.latency_ < b.latency_;
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
default: throw std::invalid_argument("Unsupported metric type");
}
}
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
{
os << "{\n"
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
<< ",\n"
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
<< "}";
return os;
}
};
struct KernelInstance
{
std::string name_;
GemmProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \""
<< "{\n"
<< obj.name_ << "\n}"
<< "\",\n"
<< " \"problem\": \"" << obj.problem_ << "\",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
struct Setting
{
int n_warmup_;
int n_repeat_;
bool is_gpu_timer_;
int verify_;
int init_method_;
bool log_;
std::string csv_filename_;
bool flush_cache_;
int rotating_count_;
};
inline std::string get_rocm_version()
{
std::ifstream version_file("/opt/rocm/.info/version");
if(version_file.is_open())
{
std::string version;
std::getline(version_file, version);
return version;
}
return "Unknown";
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations
bool compare(ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
const float max_accumulated_value =
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
void gemm_host_reference(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
{
if(verify == 1)
{
c_m_n_host_result.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_result);
}
else if(verify == 2)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Restore input for B for gpu reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
c_m_n_host_result.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
}
}

View File

@@ -0,0 +1,239 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Mappings and utility functions for kernel code generation.
"""
import subprocess
import re
from functools import lru_cache
DATA_TYPE_MAP = {'fp32': 'float',
'fp16': 'ck_tile::half_t',
'bf16': 'ck_tile::bf16_t',
'int8': 'ck_tile::int8_t',
'fp8': 'ck_tile::fp8_t',
'bf8': 'ck_tile::bf8_t',
'int4': 'ck_tile::pk_int4_t'
}
LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor',
'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'}
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
kPadM,
kPadN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
"""
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
CDataType,
CLayout,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
WarpM,
WarpN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
HOT_LOOP_FALSE = """
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
}
"""
RUN_MEM = """
// Handle One and Full cases directly
if (tail_num == ck_tile::TailNumber::One) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
} else if (tail_num == ck_tile::TailNumber::Full) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
auto check_tail = [&](auto... TNs) {
([&]{
if constexpr(BaseGemmPipeline::PrefetchStages > static_cast<int>(decltype(TNs)::value)) {
if(tail_num == decltype(TNs)::value) {
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, decltype(TNs)::value>{});
}
}
}(), ...);
};
check_tail(
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
);
"""
RUN_COMPV3 = """
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
}
"""
RUN_COMPV4 = """
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
"""
PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'}
EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE,
'cshuffle': CSHUFFLE_EPILOGUE}
HOT_LOOP_TRUE = {'mem': RUN_MEM,
'compv3': RUN_COMPV3,
'compv4': RUN_COMPV4}
def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
# To Do: add some more supported combinations
warp_tile_supported_combinations = {
"gfx90a": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
},
"gfx942": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
},
"gfx950": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}
}
# To Do: remove some unsupported combinations
trait_unsupported_combinations = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave")
}
def element_size(data_type: str) -> float:
"""Calculate the size (in bytes) of a single element for given data type."""
data_type = data_type.lower()
if data_type in {'fp16', 'bf16'}:
return 2
elif data_type in {'int8', 'fp8', 'bf8'}:
return 1
elif data_type == 'int4':
return 0.5
else:
raise ValueError(f"Unsupported data type: {data_type}")
GPU_NAME_PATTERN = re.compile(r'Name:\s*(gfx\d+\w*)')
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"],
text=True,
stderr=subprocess.PIPE,
timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
return ""
except subprocess.CalledProcessError as e:
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
except FileNotFoundError:
print("ROCm tools not installed (requires rocminfo)")
except subprocess.TimeoutExpired:
print("GPU query timeout (5s)")
except Exception as e:
print(f"GPU detection error: {str(e)}")
return ""

View File

@@ -0,0 +1,130 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
},
"datatype_a": {
"values": [
"fp16"
]
},
"datatype_b": {
"values": [
"fp16"
]
},
"datatype_c": {
"values": [
"fp16"
]
}
},
"tile_config": {
"tile_m": {
"max": 512,
"min": 64,
"step": 64,
"exclude": []
},
"tile_n": {
"max": 512,
"min": 64,
"step": 32,
"exclude": []
},
"tile_k": {
"max": 512,
"min": 64,
"step": 64,
"exclude": []
},
"warp_m": {
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16,
32
]
},
"warp_tile_n": {
"values": [
16,
32
]
},
"warp_tile_k": {
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv4",
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -1,62 +0,0 @@
{
"architecture": {
"values": ["gfx90a"]
},
"layout_a": {
"values": ["r"]
},
"layout_b": {
"values": ["c"]
},
"layout_c": {
"values": ["r"]
},
"datatype": {
"values": ["fp16"]
},
"tile_m": {
"values": [256]
},
"tile_n": {
"values": [256]
},
"tile_k": {
"values": [32]
},
"warp_m": {
"values": [2]
},
"warp_n": {
"values": [2]
},
"warp_k": {
"values": [1]
},
"warp_tile_m": {
"values": [32]
},
"warp_tile_n": {
"values": [32]
},
"warp_tile_k": {
"values": [16]
},
"kPadM": {
"values": [false]
},
"kPadN": {
"values": [false]
},
"kPadK": {
"values": [false]
},
"pipeline": {
"values": ["compv3", "compv4", "mem"]
},
"scheduler": {
"values": ["intrawave", "interwave"]
},
"epilogue": {
"values": ["default", "cshuffle"]
}
}

View File

@@ -0,0 +1,116 @@
{
"problem": {
"layout_a": {
"values": [
"r"
]
},
"layout_b": {
"values": [
"c"
]
},
"layout_c": {
"values": [
"r"
]
},
"datatype_a": {
"values": [
"fp16"
]
},
"datatype_b": {
"values": [
"fp16"
]
},
"datatype_c": {
"values": [
"fp16"
]
}
},
"tile_config": {
"tile_m": {
"values": [
128
]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"default",
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -1,192 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "gemm_common.hpp"
#include "gemm_dispatcher.hpp"
#include "gemm_host_api.hpp"
void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
int verify,
bool structured_sparsity,
KernelTraits& trait,
ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& stream)
{
return GemmDispatcher::dispatch(c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
verify,
structured_sparsity,
trait,
args,
stream);
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void run(const ck_tile::ArgParser& arg_parser)
{
const ALayout a_layout = ALayout{};
const BLayout b_layout = BLayout{};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
int verify = arg_parser.get_int("v");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
}
if(structured_sparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs gemm_args;
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
gemm_args.k_batch = kbatch;
gemm_args.M = M;
gemm_args.N = N;
gemm_args.K = K;
gemm_args.stride_A = stride_A;
gemm_args.stride_B = stride_B;
gemm_args.stride_C = stride_C;
KernelTraits trait;
trait.pipeline = arg_parser.get_str("pipeline");
trait.scheduler = arg_parser.get_str("scheduler");
trait.epilogue = arg_parser.get_str("epilogue");
trait.kPadM = arg_parser.get_bool("pad_m");
trait.kPadN = arg_parser.get_bool("pad_n");
trait.kPadK = arg_parser.get_bool("pad_k");
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name << std::endl;
ck_tile::HostTensor<CDataType> c_m_n_host_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
if(verify)
{
gemm_host_reference<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(verify,
a_m_k,
b_k_n,
c_m_n_host_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C);
}
gemm_kernel_launch(c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
verify,
structured_sparsity,
trait,
gemm_args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
return;
}
int main(int argc, char* argv[])
{
try
{
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
run<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(parser);
return 0;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
return EXIT_FAILURE;
}
}

218
tile_engine/ops/gemm/gemm_host_api.hpp Executable file → Normal file
View File

@@ -1,16 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#pragma once
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/ops/gemm.hpp"
#pragma once
#include "ck_tile/host.hpp"
#include "gemm_dispatcher.hpp"
#include "gemm_common.hpp"
template <typename T>
struct DataTypeTraits;
@@ -57,24 +56,6 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
static constexpr const char* name = "pk_int4_t";
};
/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a
/// specific kernel instance based on the provided settings.
struct KernelTraits
{
/// @brief The name of the pipeline.
std::string pipeline;
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
std::string scheduler;
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
std::string epilogue;
/// @brief Indicates whether padding is applied to the M dimension.
bool kPadM;
/// @brief Indicates whether padding is applied to the N dimension.
bool kPadN;
/// @brief Indicates whether padding is applied to the K dimension.
bool kPadK;
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
@@ -82,49 +63,76 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
inline auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("split_k", "1", "splitK value")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("structured_sparsity", "0", "0:false, 1:true")
.insert("pipeline", "compv3", "compv3, compv4, mem")
.insert("scheduler", "intrawave", "intrawave, interwave")
.insert("epilogue", "cshuffle", "cshuffle, default")
.insert("pad_m", "false", "true, false")
.insert("pad_n", "false", "true, false")
.insert("pad_k", "false", "true, false");
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
.insert("n", "4096", "The value for n dimension. Default is 4096.")
.insert("k", "2048", "The value for k dimension. Default is 2048.")
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
.insert("stride_c", "0", "The stride value for tensor C Default is 0.")
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
.insert("verify",
"2",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
"for validation on GPU. Default is 2, validation on GPU.")
.insert("log",
"false",
"Wether output kernel instance information or not. Possible values are true or "
"false. Default is false")
.insert(
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
.insert(
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
.insert("timer",
"true",
"Whether if the timer is gpu timer or not. Possible values are false or true. "
"Default is true.")
.insert("init",
"0",
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
"for constant(1). Default is 0, random.")
.insert("flush_cache",
"false",
"To flush cache, possible values are true or false. "
"Default is false.")
.insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
.insert("metric",
"0",
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
"tflops, or 2 for bandwidth. Default is 0, latency.")
.insert("csv_filename",
"gemm_kernel",
"The filename of benchmark result. Default is gemm_kernel.")
.insert("structured_sparsity",
"false",
"Whether use sparsity kernel or not. Possible values are true or false. Default is "
"false")
.insert(
"pipeline",
"compv3",
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
.insert("scheduler",
"intrawave",
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
"compv3.")
.insert(
"epilogue",
"cshuffle",
"The type of epilogue. Possible values are cshuffle or default. Default is csshuffle.")
.insert("pad_m",
"false",
"Whether pad or not in m direction. Possible values are true or false. Default is "
"false.")
.insert("pad_n",
"false",
"Whether pad or not in n direction. Possible values are true or false. Default is "
"false.")
.insert("pad_k",
"false",
"Whether pad or not in k direction. Possible values are true or false. Default is "
"false.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -185,79 +193,17 @@ void permute_vectors_i4x4_b(Tensor& tensor)
}
}
/// @brief Function to compare the results of the device and host computations
void compare(ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
{
const float max_accumulated_value =
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
KernelTraits trait;
trait.pipeline = arg_parser.get_str("pipeline");
trait.scheduler = arg_parser.get_str("scheduler");
trait.epilogue = arg_parser.get_str("epilogue");
trait.pad_m = arg_parser.get_bool("pad_m");
trait.pad_n = arg_parser.get_bool("pad_n");
trait.pad_k = arg_parser.get_bool("pad_k");
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void gemm_host_reference(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
{
if(verify == 1)
{
c_m_n_host_result.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_result);
}
else if(verify == 2)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Restore input for B for gpu reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
c_m_n_host_result.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
}
bool structured_sparsity = arg_parser.get_bool("structured_sparsity");
return GemmDispatcher::dispatch(structured_sparsity, trait);
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,262 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <fstream>
#include <iomanip>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "benchmark_gemm.hpp"
class GemmProfiler
{
public:
static GemmProfiler& instance(Setting setting)
{
static GemmProfiler instance{setting};
return instance;
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
const CLayout layout_c = CLayout{};
gemm_problem.stride_a_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
gemm_problem.stride_b_ = ck_tile::get_default_stride(
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
gemm_problem.stride_c_ = ck_tile::get_default_stride(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
if(setting_.init_method_ == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
}
else if(setting_.init_method_ == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
}
else if(setting_.init_method_ == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
}
if(gemm_problem.structured_sparsity_)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs gemm_args;
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
gemm_args.k_batch = gemm_problem.split_k_;
gemm_args.M = gemm_problem.m_;
gemm_args.N = gemm_problem.n_;
gemm_args.K = gemm_problem.k_;
gemm_args.stride_A = gemm_problem.stride_a_;
gemm_args.stride_B = gemm_problem.stride_b_;
gemm_args.stride_C = gemm_problem.stride_c_;
ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
if(setting_.verify_)
{
gemm_host_reference(setting_.verify_,
a_m_k,
b_k_n,
c_m_n_host_result,
a_m_k_dev_buf,
b_k_n_dev_buf,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
gemm_problem.stride_c_);
}
for(auto& callable : callables)
{
auto kernel_run_result = callable(gemm_args,
ck_tile::stream_config{nullptr,
true,
setting_.log_,
setting_.n_warmup_,
setting_.n_repeat_,
setting_.is_gpu_timer_,
setting_.flush_cache_,
setting_.rotating_count_});
process_result(gemm_problem,
c_m_n_dev_buf,
c_m_n_host_result,
c_m_n_dev_result,
kernel_run_result);
}
}
void process_result(const GemmProblem& gemm_problem,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
// compute performance metric
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
// update
kernel_instance.perf_result_.latency_ = avg_time;
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
if(setting_.log_ > 0)
{
std::cout << kernel_instance << std::endl;
}
// verify result
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool verified_correct =
!setting_.verify_ ||
compare(gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
if(verified_correct)
{
kernel_instances_.emplace_back(kernel_instance);
}
else
{
std::cout << "Verification failed, skip kernel: " << name << std::endl;
}
// clear tensor
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
}
KernelInstance select_best_instance(Metric metric)
{
if(kernel_instances_.empty())
throw std::runtime_error("Empty instances");
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
kernel_instances_.end(),
[metric](const auto& a, const auto& b) {
return PerformanceResult::compare(
b.perf_result_, a.perf_result_, metric);
});
std::cout << "**********************************" << std::endl;
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
<< "The best kernel instance is: " << kernel_instance << std::endl;
std::cout << "**********************************" << std::endl;
if(!setting_.csv_filename_.empty())
{
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
if(!file.is_open())
{
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
}
else
{
if(file.tellp() == 0)
{
file << "rocm_version,device_name,"
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
<< "dtype_a,dtype_b,dtype_acc,dtype_c,"
<< "layout_a,layout_b,layout_c,"
<< "structured_sparsity,"
<< "name,"
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
}
const auto& problem = kernel_instance.problem_;
const auto& name = kernel_instance.name_;
const auto& perf = kernel_instance.perf_result_;
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
<< "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
<< "\n";
if(!file)
{
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
}
}
}
return kernel_instance;
}
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler& operator=(const GemmProfiler&) = delete;
private:
~GemmProfiler() { kernel_instances_.clear(); }
GemmProfiler(Setting setting) : setting_(setting) {}
Setting setting_;
std::vector<KernelInstance> kernel_instances_;
};

View File

@@ -0,0 +1,202 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Handles loading, parsing, and validation of JSON configuration parameters.
"""
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple, Type, Dict
import json
@dataclass
class EnumConfigParam:
"""Represents an enumeration-type configuration parameter"""
values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
"""Represents a numeric range-type configuration parameter"""
min: int
max: int
step: int
exclude: Optional[List[int]]
def generate_candidates(self) -> List[int]:
"""Generates valid candidates after applying range constraints"""
if self.min > self.max:
raise ValueError(
f"Invalid range: min({self.min}) > max({self.max})"
)
if self.step <= 0:
raise ValueError(
f"Step must be positive, got {self.step}"
)
candidates = list(range(self.min, self.max + 1, self.step))
if hasattr(self, 'exclude') and self.exclude:
if not isinstance(self.exclude, list):
raise TypeError("exclude must be list type")
exclude_set = set(self.exclude)
candidates = [x for x in candidates if x not in exclude_set]
if not candidates:
raise ValueError(
f"No valid candidates for range [{self.min}-{self.max}] "
f"with step {self.step} and excludes {self.exclude}"
)
return candidates
@dataclass
class ProblemConfig:
"""configuration class for problem parameter."""
datatypes: Tuple[EnumConfigParam, ...]
layouts: Tuple[EnumConfigParam, ...]
@property
def datatype_map(self) -> Dict[str, str]:
"""Get datatype as a key-value map."""
return {
'matrix_a': self.datatypes[0].values[0],
'matrix_b': self.datatypes[1].values[0],
'matrix_c': self.datatypes[2].values[0]
}
@property
def layout_map(self) -> Dict[str, str]:
"""Get layout as a key-value map."""
return {
'matrix_a': self.layouts[0].values[0],
'matrix_b': self.layouts[1].values[0],
'matrix_c': self.layouts[2].values[0]
}
@dataclass
class TileConfig:
"""Configuration class for tile parameter."""
tile_m: Union[EnumConfigParam, RangeConfigParam]
tile_n: Union[EnumConfigParam, RangeConfigParam]
tile_k: Union[EnumConfigParam, RangeConfigParam]
warp_m: Union[EnumConfigParam, RangeConfigParam]
warp_n: Union[EnumConfigParam, RangeConfigParam]
warp_k: Union[EnumConfigParam, RangeConfigParam]
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
@dataclass
class TraitConfig:
"""Configuration class for kernel traits."""
pipeline: EnumConfigParam
scheduler: EnumConfigParam
epilogue: EnumConfigParam
pad_m: EnumConfigParam
pad_n: EnumConfigParam
pad_k: EnumConfigParam
@dataclass
class GemmConfig:
"""Main configuration class for GEMM operations """
problem: ProblemConfig
tile_config: TileConfig
trait_config: TraitConfig
@classmethod
def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
"""JSON configuration loader with validation controls"""
config_path = Path(filepath)
try:
if not config_path.exists():
raise FileNotFoundError(f"Config file {filepath} not found")
with config_path.open('r') as f:
config_dict = json.load(f)
# Parse problem config
problem = ProblemConfig(
datatypes=(
EnumConfigParam(
values=config_dict['problem']['datatype_a']['values']),
EnumConfigParam(
values=config_dict['problem']['datatype_b']['values']),
EnumConfigParam(
values=config_dict['problem']['datatype_c']['values'])
),
layouts=(
EnumConfigParam(
values=config_dict['problem']['layout_a']['values']),
EnumConfigParam(
values=config_dict['problem']['layout_b']['values']),
EnumConfigParam(
values=config_dict['problem']['layout_c']['values'])
)
)
# Parse tile config
def create_param(param_dict):
if 'values' in param_dict:
return EnumConfigParam(values=param_dict['values'])
else:
return RangeConfigParam(
min=param_dict['min'],
max=param_dict['max'],
step=param_dict['step'],
exclude=param_dict.get('exclude', [])
)
tile_config = TileConfig(
tile_m=create_param(config_dict['tile_config']['tile_m']),
tile_n=create_param(config_dict['tile_config']['tile_n']),
tile_k=create_param(config_dict['tile_config']['tile_k']),
warp_m=create_param(config_dict['tile_config']['warp_m']),
warp_n=create_param(config_dict['tile_config']['warp_n']),
warp_k=create_param(config_dict['tile_config']['warp_k']),
warp_tile_m=create_param(
config_dict['tile_config']['warp_tile_m']),
warp_tile_n=create_param(
config_dict['tile_config']['warp_tile_n']),
warp_tile_k=create_param(
config_dict['tile_config']['warp_tile_k'])
)
# Parse trait config
trait_config = TraitConfig(
pipeline=EnumConfigParam(
values=config_dict['trait_config']['pipeline']['values']),
scheduler=EnumConfigParam(
values=config_dict['trait_config']['scheduler']['values']),
epilogue=EnumConfigParam(
values=config_dict['trait_config']['epilogue']['values']),
pad_m=EnumConfigParam(
values=config_dict['trait_config']['pad_m']['values']),
pad_n=EnumConfigParam(
values=config_dict['trait_config']['pad_n']['values']),
pad_k=EnumConfigParam(
values=config_dict['trait_config']['pad_k']['values'])
)
return cls(
problem=problem,
tile_config=tile_config,
trait_config=trait_config
)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {str(e)}")
except KeyError as e:
raise KeyError(f"Missing required configuration field: {str(e)}")