diff --git a/CHANGELOG.md b/CHANGELOG.md index a1163f059c..d62a64f3e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. +* Added benchmarking support for tile engine GEMM. ### Optimized diff --git a/CMakeLists.txt b/CMakeLists.txt index a8bab1c971..e5e64055a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,10 @@ set(version 1.1.0) project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) +option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) +option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) + # Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" # CK Codegen requires dataclass which is added in Python 3.7 # Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 @@ -390,146 +394,152 @@ else() add_compile_definitions(__HIP_PLATFORM_HCC__=1) endif() -## tidy include(EnableCompilerWarnings) +## tidy set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") - set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) +set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) # Enable tidy on hip elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") - set(CK_TIDY_ERRORS ALL) +set(CK_TIDY_ERRORS ALL) endif() -include(ClangTidy) -enable_clang_tidy( - CHECKS - * - -abseil-* - -android-cloexec-fopen - # Yea we shouldn't be using rand() - -cert-msc30-c - -bugprone-exception-escape - -bugprone-macro-parentheses - -cert-env33-c - -cert-msc32-c - -cert-msc50-cpp - -cert-msc51-cpp - -cert-dcl37-c - -cert-dcl51-cpp - -clang-analyzer-alpha.core.CastToStruct - -clang-analyzer-optin.performance.Padding - -clang-diagnostic-deprecated-declarations - -clang-diagnostic-extern-c-compat - -clang-diagnostic-unused-command-line-argument - -cppcoreguidelines-avoid-c-arrays - -cppcoreguidelines-avoid-magic-numbers - -cppcoreguidelines-explicit-virtual-functions - -cppcoreguidelines-init-variables - -cppcoreguidelines-macro-usage - -cppcoreguidelines-non-private-member-variables-in-classes - -cppcoreguidelines-pro-bounds-array-to-pointer-decay - -cppcoreguidelines-pro-bounds-constant-array-index - -cppcoreguidelines-pro-bounds-pointer-arithmetic - -cppcoreguidelines-pro-type-member-init - -cppcoreguidelines-pro-type-reinterpret-cast - -cppcoreguidelines-pro-type-union-access - -cppcoreguidelines-pro-type-vararg - -cppcoreguidelines-special-member-functions - -fuchsia-* - -google-explicit-constructor - -google-readability-braces-around-statements - -google-readability-todo - -google-runtime-int - -google-runtime-references - -hicpp-vararg - -hicpp-braces-around-statements - -hicpp-explicit-conversions - -hicpp-named-parameter - -hicpp-no-array-decay - # We really shouldn't use bitwise operators with signed integers, but - # opencl leaves us no choice - -hicpp-avoid-c-arrays - -hicpp-signed-bitwise - -hicpp-special-member-functions - -hicpp-uppercase-literal-suffix - -hicpp-use-auto - -hicpp-use-equals-default - -hicpp-use-override - -llvm-header-guard - -llvm-include-order - #-llvmlibc-* - -llvmlibc-restrict-system-libc-headers - -llvmlibc-callee-namespace - -llvmlibc-implementation-in-namespace - -llvm-else-after-return - -llvm-qualified-auto - -misc-misplaced-const - -misc-non-private-member-variables-in-classes - -misc-no-recursion - -modernize-avoid-bind - -modernize-avoid-c-arrays - -modernize-pass-by-value - -modernize-use-auto - -modernize-use-default-member-init - -modernize-use-equals-default - -modernize-use-trailing-return-type - -modernize-use-transparent-functors - -performance-unnecessary-value-param - -readability-braces-around-statements - -readability-else-after-return - # we are not ready to use it, but very useful - -readability-function-cognitive-complexity - -readability-isolate-declaration - -readability-magic-numbers - -readability-named-parameter - -readability-uppercase-literal-suffix - -readability-convert-member-functions-to-static - -readability-qualified-auto - -readability-redundant-string-init - # too many narrowing conversions in our code - -bugprone-narrowing-conversions - -cppcoreguidelines-narrowing-conversions - -altera-struct-pack-align - -cppcoreguidelines-prefer-member-initializer - ${CK_TIDY_CHECKS} - ${CK_TIDY_ERRORS} - HEADER_FILTER - "\.hpp$" - EXTRA_ARGS - -DCK_USE_CLANG_TIDY -) +if(ENABLE_CLANG_CPP_CHECKS) + include(ClangTidy) + enable_clang_tidy( + CHECKS + * + -abseil-* + -android-cloexec-fopen + # Yea we shouldn't be using rand() + -cert-msc30-c + -bugprone-exception-escape + -bugprone-macro-parentheses + -cert-env33-c + -cert-msc32-c + -cert-msc50-cpp + -cert-msc51-cpp + -cert-dcl37-c + -cert-dcl51-cpp + -clang-analyzer-alpha.core.CastToStruct + -clang-analyzer-optin.performance.Padding + -clang-diagnostic-deprecated-declarations + -clang-diagnostic-extern-c-compat + -clang-diagnostic-unused-command-line-argument + -cppcoreguidelines-avoid-c-arrays + -cppcoreguidelines-avoid-magic-numbers + -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables + -cppcoreguidelines-macro-usage + -cppcoreguidelines-non-private-member-variables-in-classes + -cppcoreguidelines-pro-bounds-array-to-pointer-decay + -cppcoreguidelines-pro-bounds-constant-array-index + -cppcoreguidelines-pro-bounds-pointer-arithmetic + -cppcoreguidelines-pro-type-member-init + -cppcoreguidelines-pro-type-reinterpret-cast + -cppcoreguidelines-pro-type-union-access + -cppcoreguidelines-pro-type-vararg + -cppcoreguidelines-special-member-functions + -fuchsia-* + -google-explicit-constructor + -google-readability-braces-around-statements + -google-readability-todo + -google-runtime-int + -google-runtime-references + -hicpp-vararg + -hicpp-braces-around-statements + -hicpp-explicit-conversions + -hicpp-named-parameter + -hicpp-no-array-decay + # We really shouldn't use bitwise operators with signed integers, but + # opencl leaves us no choice + -hicpp-avoid-c-arrays + -hicpp-signed-bitwise + -hicpp-special-member-functions + -hicpp-uppercase-literal-suffix + -hicpp-use-auto + -hicpp-use-equals-default + -hicpp-use-override + -llvm-header-guard + -llvm-include-order + #-llvmlibc-* + -llvmlibc-restrict-system-libc-headers + -llvmlibc-callee-namespace + -llvmlibc-implementation-in-namespace + -llvm-else-after-return + -llvm-qualified-auto + -misc-misplaced-const + -misc-non-private-member-variables-in-classes + -misc-no-recursion + -modernize-avoid-bind + -modernize-avoid-c-arrays + -modernize-pass-by-value + -modernize-use-auto + -modernize-use-default-member-init + -modernize-use-equals-default + -modernize-use-trailing-return-type + -modernize-use-transparent-functors + -performance-unnecessary-value-param + -readability-braces-around-statements + -readability-else-after-return + # we are not ready to use it, but very useful + -readability-function-cognitive-complexity + -readability-isolate-declaration + -readability-magic-numbers + -readability-named-parameter + -readability-uppercase-literal-suffix + -readability-convert-member-functions-to-static + -readability-qualified-auto + -readability-redundant-string-init + # too many narrowing conversions in our code + -bugprone-narrowing-conversions + -cppcoreguidelines-narrowing-conversions + -altera-struct-pack-align + -cppcoreguidelines-prefer-member-initializer + ${CK_TIDY_CHECKS} + ${CK_TIDY_ERRORS} + HEADER_FILTER + "\.hpp$" + EXTRA_ARGS + -DCK_USE_CLANG_TIDY + ) -include(CppCheck) -enable_cppcheck( - CHECKS - warning - style - performance - portability - SUPPRESS - ConfigurationNotChecked - constStatement - duplicateCondition - noExplicitConstructor - passedByValue - preprocessorErrorDirective - shadowVariable - unusedFunction - unusedPrivateFunction - unusedStructMember - unmatchedSuppression - FORCE - SOURCES - library/src - INCLUDE - ${CMAKE_CURRENT_SOURCE_DIR}/include - ${CMAKE_CURRENT_BINARY_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR}/library/include - DEFINE - CPPCHECK=1 - __linux__=1 -) + include(CppCheck) + enable_cppcheck( + CHECKS + warning + style + performance + portability + SUPPRESS + ConfigurationNotChecked + constStatement + duplicateCondition + noExplicitConstructor + passedByValue + preprocessorErrorDirective + shadowVariable + unusedFunction + unusedPrivateFunction + unusedStructMember + unmatchedSuppression + FORCE + SOURCES + library/src + INCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/library/include + DEFINE + CPPCHECK=1 + __linux__=1 + ) +else() + function(clang_tidy_check TARGET) + # stub out empty function if clang tidy is not enabled + endfunction() +endif() set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) @@ -557,12 +567,15 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() -# make check runs the entire set of examples and tests -add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) -# make smoke runs the tests and examples that runs within 30 seconds on gfx90a -add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") -# make regression runs the tests and examples that runs for more 30 seconds on gfx90a -add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") +if(NOT MIOPEN_REQ_LIBS_ONLY) + # make check runs the entire set of examples and tests + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + # make smoke runs the tests and examples that runs within 30 seconds on gfx90a + add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") + # make regression runs the tests and examples that runs for more 30 seconds on gfx90a + add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") +endif() + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") @@ -607,6 +620,7 @@ ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) add_subdirectory(library) diff --git a/Jenkinsfile b/Jenkinsfile index c26350f120..c4b5efe3bc 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -114,6 +114,9 @@ def check_arch(){ else if ( runShell('grep -n "gfx908" rocminfo.log') ) { arch_type = 6 } + else if ( runShell('grep -n "gfx950" rocminfo.log') ) { + arch_type = 7 + } return arch_type } @@ -132,6 +135,10 @@ def getDockerImage(Map conf=[:]){ image = conf.get("docker_name", "") echo "Using legacy docker: ${image}" } + else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using special docker: ${image}" + } else{ image = getDockerImageName() echo "Using default docker: ${image}" @@ -208,6 +215,11 @@ def cmake_build(Map conf=[:]){ def build_type_debug = (conf.get("build_type",'release') == 'debug') + // use special compiler for gfx950 + if ( check_arch() == 7){ + compiler = "/llvm-project/build/bin/clang++" + } + //cmake_env can overwrite default CXX variables. def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") @@ -263,6 +275,9 @@ def cmake_build(Map conf=[:]){ if (setup_args.contains("gfx94")){ invocation_tag="gfx94" } + if (setup_args.contains("gfx95")){ + invocation_tag="gfx95" + } echo "invocation tag: ${invocation_tag}" def redis_pre_setup_cmd = pre_setup_cmd if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") { @@ -422,16 +437,6 @@ def buildHipClangJob(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - - def image - if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ - image = conf.get("docker_name", "") - echo "Using legacy docker: ${image}" - } - else{ - image = getDockerImageName() - echo "Using default docker: ${image}" - } def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group @@ -455,7 +460,7 @@ def buildHipClangJob(Map conf=[:]){ echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME - + def image def retimage (retimage, image) = getDockerImage(conf) @@ -496,17 +501,6 @@ def Build_CK(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 env.DOCKER_BUILDKIT=1 checkout scm - - def image - if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ - image = conf.get("docker_name", "") - echo "Using legacy docker: ${image}" - } - else{ - image = getDockerImageName() - echo "Using default docker: ${image}" - } - def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group @@ -527,6 +521,7 @@ def Build_CK(Map conf=[:]){ echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME + def image def retimage gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { @@ -638,6 +633,13 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_onnx_gemm_gfx908.log" stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" } + else if ( arch == 7 ){ + // run basic tests on gfx950 + echo "Run performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950" + archiveArtifacts "perf_onnx_gemm_gfx950.log" + stash includes: "perf_onnx_gemm_gfx950.log", name: "perf_log_gfx950" + } } } if (params.hipTensor_test && arch == 1 ){ @@ -774,8 +776,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true @@ -848,8 +850,8 @@ pipeline { description: "Run the grouped conv large cases tests (default: OFF)") booleanParam( name: "RUN_CODEGEN_TESTS", - defaultValue: false, - description: "Run codegen tests (default: OFF)") + defaultValue: true, + description: "Run codegen tests (default: ON)") booleanParam( name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, @@ -862,6 +864,10 @@ pipeline { name: "RUN_CK_TILE_GEMM_TESTS", defaultValue: false, description: "Run the ck_tile GEMM tests (default: OFF)") + booleanParam( + name: "RUN_TILE_ENGINE_GEMM_TESTS", + defaultValue: false, + description: "Run the tile_engine_gemm tests (default: OFF)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, @@ -870,6 +876,10 @@ pipeline { name: "BUILD_GFX908", defaultValue: false, description: "Build CK and run tests on gfx908 (default: OFF)") + booleanParam( + name: "BUILD_GFX950", + defaultValue: false, + description: "Build CK and run tests on gfx950 (default: OFF)") booleanParam( name: "BUILD_GFX12", defaultValue: true, @@ -1145,6 +1155,48 @@ pipeline { } } } + stage("Run TILE_ENGINE_GEMM Tests") + { + parallel + { + stage("Run TILE_ENGINE_GEMM Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make benchmark_gemm -j && \ + ./bin/benchmark_gemm """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run TILE_ENGINE_GEMM Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make benchmark_gemm -j && \ + ./bin/benchmark_gemm """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Build CK and run Tests") { @@ -1188,7 +1240,7 @@ pipeline { cleanWs() } } - stage("Build CK for all gfx9 targets") + stage("Build CK and run Tests on gfx942") { when { beforeAgent true @@ -1203,6 +1255,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1210,6 +1263,29 @@ pipeline { cleanWs() } } + stage("Build CK and run Tests on gfx950") + { + when { + beforeAgent true + expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + } + agent{ label rocmnode("gfx950") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx950" \ + -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx950" \ + -DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "rocm/composable_kernel-private:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } stage("Build CK and run Tests on gfx908") { when { @@ -1223,6 +1299,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1243,6 +1320,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1250,7 +1328,7 @@ pipeline { cleanWs() } } - stage("Build CK instances for different targets") + stage("Build CK instances for all supported targets") { when { beforeAgent true @@ -1281,6 +1359,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx1030" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1301,6 +1380,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx1101" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1321,6 +1401,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx1201" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 9e7c360f54..8ddc663452 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -48,6 +48,7 @@ rocm_install_targets( INCLUDE include ) rocm_export_targets( + TARGETS ck_host ck_headers EXPORT ck_host_targets NAMESPACE composable_kernel:: ) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 6c48b2de09..725a745f3a 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core[api_reference]==1.18.4 +rocm-docs-core[api_reference]==1.20.0 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 62c3ea8ff8..f74ad725af 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -237,7 +237,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core[api-reference]==1.18.4 +rocm-docs-core[api-reference]==1.20.0 # via -r requirements.in rpds-py==0.24.0 # via diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 76b9429b2e..0f5670f1b9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -3,7 +3,7 @@ # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass +from dataclasses import dataclass, field import fnmatch import itertools from pathlib import Path @@ -117,8 +117,50 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" FMHA_FWD_API=""" -float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{ +#include + +namespace {{ +bool get_num_cus(unsigned& num_cu) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cu = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{ float r = -1; + + const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + {F_dispatch} return r; }} @@ -134,36 +176,50 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_batch_prefill_(s, a); }} """ +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + @dataclass class FmhaFwdApiTrait: pipeline_tag : str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - logits : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + constraint : CppConstraint @property def name(self) -> str: @@ -220,17 +276,18 @@ class FmhaFwdApiTrait: class FmhaFwdPipeline: tag : str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_logits : str # t/f - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: @@ -297,8 +354,8 @@ class FmhaFwdApiPool: inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) @@ -313,25 +370,27 @@ class FmhaFwdApiPool: @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ @@ -423,33 +482,21 @@ class FmhaFwdKernel: spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) + dvpad=self.F_pipeline.F_dvpad, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - ### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - ### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } - else: - return None +class KernelComponentFactory: + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: - # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad - # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + @staticmethod + def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! @@ -458,53 +505,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - else: - if bias == "bias": - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: - # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) - elif dtype in ['fp8fp16', 'fp8bf16']: - # TODO - None + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: assert False return pipelines +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),] + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + gen = list() api_pool = FmhaFwdApiPool(mask_impl) for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) + d = CustomFactory.get_hdim_tile_size_dict(dtype) if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] + tiles = d[hdim_str] hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 2f1287c87a..7cbbdb9034 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -58,7 +58,8 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_lse}, {F_dropout}, {F_squant}, - {F_occupancy}>; + {F_occupancy}, + {F_skip}>; using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; @@ -94,7 +95,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; #include @@ -129,9 +130,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -160,11 +161,12 @@ class FmhaFwdApiTrait: skpad : str dpad : str dvpad : str + skip : str @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' @property def scheck(self) -> str: @@ -227,6 +229,7 @@ class FmhaFwdPipeline: F_dropout : str # F_squant : str # F_mask : str # value from MASK_MAP + F_skip : str # true/false @property def name(self) -> str: @@ -262,8 +265,12 @@ class FmhaFwdPipeline: if self.F_dropout == 't' : n += '_dropout' else: n += '_ndropout' + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' + return n class FmhaFwdApiPool: @@ -293,7 +300,7 @@ class FmhaFwdApiPool: inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, @@ -381,6 +388,7 @@ class FmhaFwdKernel: F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], @@ -419,7 +427,8 @@ class FmhaFwdKernel: spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip) # TODO: design a more practical way to do it # this is current supported tile size per hdim @@ -453,36 +462,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -508,7 +517,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue if hdim == 192 and tile.F_bn1 == 128: # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): @@ -532,6 +541,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' if not cond: continue # PyTorch integration @@ -540,6 +550,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'bias'] cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' if not cond: continue # Aiter(mha_fwd) integration @@ -565,6 +576,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_squant == 'f' if not cond: continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 1838ee5bd9..5ce56d48b5 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -169,6 +169,7 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; float p_drop; bool s_randval; @@ -433,6 +434,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.window_size_left, args.window_size_right, args.mask_type, + args.min_seqlen_q, args.p_drop, args.s_randval, args.drop_seed_offset); @@ -837,7 +839,8 @@ template + bool kPadDv_, + bool kSkipMinSeqlenQ_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -861,6 +864,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; template @@ -995,6 +999,7 @@ struct fmha_fwd_traits bool has_lse; bool has_dropout; bool do_fp8_static_quant; + bool skip_min_seqlen_q = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 1edb3da947..386fe93715 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -214,4 +214,15 @@ int run_gemm_example(int argc, char* argv[]) } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + try + { + return !run_gemm_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 25fab6bde0..4c9fecaba6 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -220,4 +220,11 @@ auto create_args(int argc, char* argv[]) } // host API +template float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 79ed9ce76b..3010130e6c 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -178,7 +178,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index b60a3b274b..5dcb685839 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -11,6 +11,7 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" +#include "run_gemm_example.inc" template void try_run(ck_tile::TailNumber tn) @@ -74,64 +75,102 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) @@ -243,8 +282,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; } -#include "run_gemm_example.inc" - template int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) { @@ -345,7 +382,7 @@ int main(int argc, char* argv[]) { try { - run_gemm_example(argc, argv); + return !run_gemm_example(argc, argv); } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index ce689a370c..da1c15b86f 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -334,16 +334,26 @@ bool test_moe_sorting(ck_tile::ArgParser args) int main(int argc, char** argv) { - auto [result, args] = create_args(argc, argv); - if(!result) - return -1; - std::string index_prec = args.get_str("pr_i"); - std::string weight_prec = args.get_str("pr_w"); - - bool r = true; - if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0) + try { - r &= test_moe_sorting(args); + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + + bool r = true; + if(weight_prec == "fp32" && index_prec == "int32") + { + r &= test_moe_sorting(args); + } + + return r ? 0 : -1; + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; } - return r ? 0 : -1; } diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index 0219c67305..68ad1106ce 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -320,4 +320,15 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + try + { + return !run_batched_gemm_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 61193e2e29..067319b3f9 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -319,4 +319,15 @@ float grouped_gemm(const std::vector& gemm_descs, #include "run_grouped_gemm_example.inc" constexpr bool Persistent = false; -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + try + { + return !run_grouped_gemm_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 5f2c2a5aab..2dbff1bc5c 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -11,6 +11,7 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" +#include "run_flatmm_example.inc" template (Kernel{}, grids, blocks, 0, kargs)); + float ave_time{0}; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } return ave_time; }; if(args.k_batch == 1) @@ -132,8 +171,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con } } -#include "run_flatmm_example.inc" - int run_flatmm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -177,4 +214,15 @@ int run_flatmm_example(int argc, char* argv[]) return -1; } -int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + try + { + return !run_flatmm_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index bbce978724..55f2d4f367 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -133,4 +133,11 @@ auto create_args(int argc, char* argv[]) } // host API +template float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index c191fff7d0..3d4f154af7 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -122,7 +122,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, float ave_time = flatmm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 57c4b1a5cf..33b6d7c585 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -393,8 +393,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { const index_t GemmM = K; const index_t GemmN = C * X; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; return transform_tensor_descriptor( wei_grid_desc, @@ -432,8 +434,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { const index_t GemmM = K; const index_t GemmN = C * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; return transform_tensor_descriptor( wei_grid_desc, @@ -472,8 +476,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle { const index_t GemmM = K; const index_t GemmN = C * X * Y * Z; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; return transform_tensor_descriptor( wei_grid_desc, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 0831b754c8..e9e02eae81 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -208,8 +208,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle const index_t GemmM = K; const index_t GemmN = C * Z * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 0e58d5acb4..badd64508d 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -351,6 +351,98 @@ struct Bilinear float beta_; }; +struct AddClamp +{ + AddClamp(float floor = 0.f, float ceil = NumericLimits::Max()) + : floor_(floor), ceil_(ceil){}; + + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > type_convert(floor_) + ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) + : type_convert(floor_); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + const float a = x0 + x1; + y = a > type_convert(floor_) + ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) + : type_convert(floor_); + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > type_convert(floor_) + ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) + : type_convert(floor_); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float a = type_convert(x0) + type_convert(x1); + y = a > type_convert(floor_) + ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) + : type_convert(floor_); + }; + + template <> + __host__ __device__ constexpr void + operator()(int& y, const int& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + const float floor_; + const float ceil_; +}; + struct AddRelu { template diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp index c11bf845d0..bd3ab10802 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -166,8 +166,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmM = K; const index_t GemmN = C * X; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -365,8 +365,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmM = K; const index_t GemmN = C * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -558,8 +558,8 @@ struct TransformConvBwdWeightToGemm const index_t GemmM = K; const index_t GemmN = C * Z * X * Y; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp index f34e0e59b3..b72ddb8243 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -346,8 +346,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmM = K * NumGroupsToMerge; const index_t GemmN = C * X * NumGroupsToMerge; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -534,8 +534,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmM = K * NumGroupsToMerge; const index_t GemmN = C * X * Y * NumGroupsToMerge; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = @@ -737,8 +737,8 @@ struct TransformConvBwdWeightToGemmV2 const index_t GemmM = K * NumGroupsToMerge; const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; - const auto PadGemmM = MPerBlock - GemmM % MPerBlock; - const auto PadGemmN = NPerBlock - GemmN % NPerBlock; + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; const index_t GemmKBatch = batch_k; const index_t GemmK0 = diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 0576162943..ae6a212725 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -11,6 +11,7 @@ /// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h. #if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +#define CHAR_BIT 8 using int8_t = signed char; using uint8_t = unsigned char; using int16_t = signed short; diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 27af59c192..be84842347 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -55,8 +55,8 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_scatter_gather.hpp" -#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/tile_window_base.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" #include "ck_tile/core/tensor/transpose_tile.hpp" diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 5d6d6ce348..68648e1c02 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -1437,8 +1437,8 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe static_assert( (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || @@ -1561,6 +1561,24 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe return bit_cast(tmp); } + else if constexpr(N == 16) + { + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.template get_as()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + return bit_cast(tmp); + } } else if constexpr(std::is_same::value) // bf16 { @@ -1597,6 +1615,24 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe return bit_cast(tmp); } + else if constexpr(N == 16) + { + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.template get_as()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + return bit_cast(tmp); + } } else // other datatype { diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 5a5e01460f..3459e728e0 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -35,4 +35,7 @@ #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/host/stream_utils.hpp" #include "ck_tile/host/timer.hpp" +#include "ck_tile/host/flush_icache.hpp" +#include "ck_tile/host/rotating_buffers.hpp" diff --git a/include/ck_tile/host/device_prop.hpp b/include/ck_tile/host/device_prop.hpp new file mode 100644 index 0000000000..d33b298369 --- /dev/null +++ b/include/ck_tile/host/device_prop.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef __HIPCC_RTC__ +#include +#include +#include + +namespace ck_tile { + +constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u) +{ + return str.empty() ? h + : fnv1a_hash(str.substr(1), + (h ^ static_cast(str.front())) * 16777619u); +} +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. + switch(fnv1a_hash(name)) + { + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + case fnv1a_hash("Ellesmere"): + case fnv1a_hash("Baffin"): + case fnv1a_hash("RacerX"): + case fnv1a_hash("Polaris10"): + case fnv1a_hash("Polaris11"): + case fnv1a_hash("Tonga"): + case fnv1a_hash("Fiji"): + case fnv1a_hash("gfx800"): + case fnv1a_hash("gfx802"): + case fnv1a_hash("gfx804"): return "gfx803"; + case fnv1a_hash("Vega10"): + case fnv1a_hash("gfx901"): return "gfx900"; + case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030"; + default: return name; + } +} +} // namespace ck_tile + +#endif diff --git a/include/ck_tile/host/flush_icache.hpp b/include/ck_tile/host/flush_icache.hpp new file mode 100644 index 0000000000..9230b50a13 --- /dev/null +++ b/include/ck_tile/host/flush_icache.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck_tile { +static __global__ void flush_cache() +{ + asm __volatile__("s_icache_inv \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" :: + :); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/kernel_launch.hpp b/include/ck_tile/host/kernel_launch.hpp index d159787387..9770e99738 100644 --- a/include/ck_tile/host/kernel_launch.hpp +++ b/include/ck_tile/host/kernel_launch.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,6 +11,13 @@ #include namespace ck_tile { + +#define LOW_CU_PROCESSORS 80 +#define HIGH_CU_PROCESSORS 228 +#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS 0.005 +#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS 0.0015 +#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01 + template #if CK_TILE_USE_LAUNCH_BOUNDS __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) @@ -81,6 +88,8 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla template CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callables) { + static_assert(sizeof...(callables) > 0, "At least one callable is required!"); + if(!s.time_kernel_) { launch_and_check(s, std::forward(callables)...); @@ -88,7 +97,7 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable } auto time_launches = [&](auto timer) { - // warmup + // Warmup for(int i = 0; i < s.cold_niters_; i++) { launch_and_check(s, std::forward(callables)...); @@ -114,4 +123,53 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable } } +template +CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s, + PreprocessFunc preprocess, + Callables&&... callables) +{ + static_assert(sizeof...(callables) > 0, "At least one callable is required!"); + + if(!s.time_kernel_) + { + preprocess(); + launch_and_check(s, std::forward(callables)...); + return 0; + } + + auto time_launches = [&](auto timer) { + // Warmup + for(int i = 0; i < s.cold_niters_; i++) + { + launch_and_check(s, std::forward(callables)...); + } + + timer.start(s.stream_id_); + for(int i = 0; i < s.nrepeat_; i++) + { + preprocess(); + launch_and_check(s, std::forward(callables)...); + } + timer.stop(s.stream_id_); + + hipDeviceProp_t deviceProps; + HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0)); + + float preprocess_offset = (deviceProps.multiProcessorCount >= HIGH_CU_PROCESSORS) + ? OPTIMAL_LATENCY_HIGH_CU_PROCESSORS + : (deviceProps.multiProcessorCount == LOW_CU_PROCESSORS) + ? OPTIMAL_LATENCY_LOW_CU_PROCESSORS + : OPTIMAL_LATENCY_SAFE_MARGIN; + return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_; + }; + + if(s.is_gpu_timer_) + { + return time_launches(gpu_timer{}); + } + else + { + return time_launches(cpu_timer{}); + } +} } // namespace ck_tile diff --git a/include/ck_tile/host/rotating_buffers.hpp b/include/ck_tile/host/rotating_buffers.hpp new file mode 100644 index 0000000000..86f68ad084 --- /dev/null +++ b/include/ck_tile/host/rotating_buffers.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include + +namespace ck_tile { + +template +struct RotatingMemWrapper +{ + RotatingMemWrapper() = delete; + RotatingMemWrapper(const void* a_ptr_, + const void* b_ptr_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + rotating_count(rotating_count_), + size_a(size_a_), + size_b(size_b_) + { + p_a_grids.push_back(a_ptr); + p_b_grids.push_back(b_ptr); + for(size_t i = 1; i < rotating_count; i++) + { + { + void* pADeviceBuf; + HIP_CHECK_ERROR(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + HIP_CHECK_ERROR(hipMemcpy(static_cast(pADeviceBuf), + const_cast(p_a_grids[0]), + size_a_, + hipMemcpyDeviceToDevice)); + p_a_grids.push_back(pADeviceBuf); + } + + { + void* pBDeviceBuf; + HIP_CHECK_ERROR(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + HIP_CHECK_ERROR(hipMemcpy(static_cast(pBDeviceBuf), + const_cast(p_b_grids[0]), + size_b_, + hipMemcpyDeviceToDevice)); + p_b_grids.push_back(pBDeviceBuf); + } + } + } + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + a_ptr = p_a_grids[idx]; + b_ptr = p_b_grids[idx]; + } + } + void Print() + { + std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapper() noexcept + { + if(rotating_count > 1) + { + // restore ptr + a_ptr = p_a_grids[0]; + b_ptr = p_b_grids[0]; + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + ck_tile::hip_check_error(hipFree(const_cast(p_a_grids[i]))); + ck_tile::hip_check_error(hipFree(const_cast(p_b_grids[i]))); + } + } + } + + private: + const void* a_ptr; + const void* b_ptr; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::size_t size_a = 0; + std::size_t size_b = 0; + std::vector p_a_grids; + std::vector p_b_grids; +}; +inline void flush_icache() +{ + hipDeviceProp_t deviceProps; + HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0)); + int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; + + ck_tile::flush_cache<<>>(); + HIP_CHECK_ERROR(hipGetLastError()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/stream_config.hpp b/include/ck_tile/host/stream_config.hpp index 47cf0fd5e4..f6bd40f6f2 100644 --- a/include/ck_tile/host/stream_config.hpp +++ b/include/ck_tile/host/stream_config.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,5 +30,7 @@ struct stream_config int cold_niters_ = 3; int nrepeat_ = 10; bool is_gpu_timer_ = true; // keep compatible + bool flush_cache_ = false; + int rotating_count_ = 1; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index a4b3765455..ac37f5dd06 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -53,6 +53,8 @@ struct FmhaFwdKernel static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ; + using AttentionVariant = ck_tile::remove_cvref_t; using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; @@ -257,6 +259,11 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval = 0; }; + struct FmhaFwdSkipMinSeqlenQKargs + { + ck_tile::index_t min_seqlen_q = 0; + }; + struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -664,6 +672,7 @@ struct FmhaFwdKernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, + ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant, std::pair> @@ -698,6 +707,7 @@ struct FmhaFwdKernel {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout {}, // placeholder for logits_soft_cap + {}, // placeholder for min_seqlen_q reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -753,6 +763,10 @@ struct FmhaFwdKernel { kargs.init_logits_soft_cap(logits_soft_cap); } + if constexpr(kSkipMinSeqlenQ) + { + kargs.min_seqlen_q = min_seqlen_q; + } return kargs; } @@ -969,7 +983,15 @@ struct FmhaFwdKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } else { @@ -989,7 +1011,15 @@ struct FmhaFwdKernel const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } } } @@ -1053,6 +1083,14 @@ struct FmhaFwdKernel const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + if constexpr(kSkipMinSeqlenQ) + { + if(kargs.seqlen_q <= kargs.min_seqlen_q) + { + return; + } + } + // # of required blocks is different in each groups, terminate unnecessary blocks // earlier if(kargs.seqlen_q <= i_m0) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 63011d2ba9..501aa26667 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -561,7 +561,16 @@ struct FmhaFwdSplitKVKernel const index_t i_nhead = blockIdx.y; const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + if constexpr(kHasMask) + { + // assume that num_tile_n1 is always 1 + return ck_tile::make_tuple( + (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } + else + { + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } } __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index f35c00c268..21cc4950eb 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -53,6 +53,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; + static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 4530b58d85..442619a3dc 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -19,7 +19,8 @@ template + index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */ + bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */> struct TileFmhaTraits { static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; @@ -33,6 +34,7 @@ struct TileFmhaTraits static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; template ; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index df24b4cbcb..7311f4bf75 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp index 6bb6d255f3..5a4d0338b0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp index 195367ffd7..6da3ee1a4f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -33,7 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 182c785978..d074988a22 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -25,7 +25,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp similarity index 69% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp rename to library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index d873edadba..39231e31f0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -13,7 +13,7 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" #ifdef CK_USE_XDL -#include "grouped_convolution_forward_bias_relu_xdl.inc" +#include "grouped_convolution_forward_bias_clamp_xdl.inc" #endif namespace ck { @@ -44,7 +44,7 @@ struct DeviceOperationInstanceFactory> { @@ -60,7 +60,7 @@ struct DeviceOperationInstanceFactory; @@ -80,23 +80,23 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( op_ptrs); - add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( op_ptrs); } #endif @@ -112,19 +112,19 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( op_ptrs); - add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); - add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( op_ptrs); - add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( op_ptrs); - add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( op_ptrs); } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc similarity index 88% rename from library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc rename to library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index 1935f123a8..cc29e66cc1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -10,7 +10,7 @@ namespace instance { #ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( std::vector>>& instances); + AddClamp>>>& instances); -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( std::vector>>& instances); + AddClamp>>>& instances); #endif diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index aef40b8cb3..67ce4e39e1 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -104,17 +104,6 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() - if(MIOPEN_REQ_LIBS_ONLY) - message("Removing all sources that are not required for MIOpen") - foreach(source IN LISTS ARGN) - if(source MATCHES "gemm" OR - source MATCHES "mha" OR - source MATCHES "contraction" OR - source MATCHES "reduce") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - endif() #message("remaining instances: ${ARGN}") #only continue if there are some source files left on the list if(ARGN) @@ -180,7 +169,7 @@ function(add_instance_library INSTANCE_NAME) target_compile_features(${INSTANCE_NAME} PUBLIC) # flags to compress the library - if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) + if(NOT DISABLE_OFFLOAD_COMPRESS AND NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) #message("Adding --offload-compress flag for ${INSTANCE_NAME}") target_compile_options(${INSTANCE_NAME} PRIVATE --offload-compress) endif() @@ -294,6 +283,17 @@ FOREACH(subdir_path ${dir_list}) message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() + + if(MIOPEN_REQ_LIBS_ONLY) + message("Removing all sources that are not required for MIOpen") + if("${cmake_instance}" MATCHES "gemm" OR + "${cmake_instance}" MATCHES "mha" OR + "${cmake_instance}" MATCHES "contraction" OR + "${cmake_instance}" MATCHES "reduce") + set(add_inst 0) + endif() + endif() + if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt new file mode 100644 index 0000000000..b0a0cbb293 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance + xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp + + xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp index 75acd604ee..1dfb7577f7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -10,7 +10,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( std::vector>>& instances) + AddClamp>>>& instances) { if(ck::get_device_name() == "gfx950") { @@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_ NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, @@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_ NHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, @@ -57,7 +57,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_ NHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index 69a8a4bd9d..171efd60da 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -10,7 +10,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_comp_instances<2, @@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_comp_instances<2, @@ -42,7 +42,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins NHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_comp_instances<2, @@ -52,7 +52,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_ins NHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp index 043c724e4a..49263b43eb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -10,7 +10,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( std::vector>>& instances) + AddClamp>>>& instances) { if(ck::get_device_name() != "gfx950") { @@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, @@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par NHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, @@ -57,7 +57,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_par NHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp index c58631e169..b418807bdf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, @@ -31,7 +31,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, @@ -41,7 +41,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in NHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, @@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_in NHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index cd80f2875f..6c666706a7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_instances<2, @@ -31,7 +31,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_instances<2, @@ -41,7 +41,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance NHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_instances<2, @@ -51,7 +51,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance NHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index a6286b55e8..cd679f4b2d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances( instances, @@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_ NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp index 0736325b05..f0638a96f5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<2, @@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte ConvFwdDefault, Interwave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<2, @@ -43,7 +43,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte ConvFwd1x1P0, Interwave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<2, @@ -54,7 +54,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inte ConvFwd1x1S1P0, Interwave, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp index 0d35ab1b05..6d07172806 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<2, @@ -32,7 +32,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr ConvFwdDefault, Intrawave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<2, @@ -43,7 +43,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr ConvFwd1x1P0, Intrawave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<2, @@ -54,7 +54,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intr ConvFwd1x1S1P0, Intrawave, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 253e8b196e..2c576431e3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -10,7 +10,7 @@ namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances) + AddClamp>>>& instances) { if(ck::get_device_name() == "gfx950") { @@ -35,7 +35,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, @@ -46,7 +46,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk NHWGK, ConvFwd3x3, Tuple, - AddRelu>{}); + AddClamp>{}); } else { @@ -59,7 +59,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk NHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, @@ -70,7 +70,7 @@ void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk NHWGK, ConvFwd3x3, Tuple, - AddRelu>{}); + AddClamp>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt deleted file mode 100644 index 98b0b1c4cb..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# ONLY XDL_KERNELS -add_instance_library(device_grouped_conv2d_fwd_bias_relu_instance - xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp - - xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - - xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - - xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp - - xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp - xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp -) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt new file mode 100644 index 0000000000..a1c3feed3b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp + + xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +) + +add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp similarity index 96% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp index 9819f0ea0b..5130312db2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -10,7 +10,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_comp_instances<3, @@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_comp_instances<3, NDHWGC, @@ -41,7 +41,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_comp_instances<3, NDHWGC, @@ -50,7 +50,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); if(ck::get_device_name() != "gfx950") { @@ -63,7 +63,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, @@ -73,7 +73,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, @@ -83,7 +83,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } if(ck::get_device_name() == "gfx950") @@ -97,7 +97,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, @@ -107,7 +107,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, @@ -117,7 +117,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_ NDHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp index dc3fc7a4bf..86dad21d43 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, @@ -31,7 +31,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16 NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, NDHWGC, @@ -40,7 +40,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16 NDHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, NDHWGC, @@ -49,7 +49,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16 NDHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index a9a8ff8459..685a729c3a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_instances<3, @@ -31,7 +31,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_instances<3, NDHWGC, @@ -40,7 +40,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta NDHWGK, ConvFwd1x1P0, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_instances<3, NDHWGC, @@ -49,7 +49,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_insta NDHWGK, ConvFwd1x1S1P0, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 93% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index e58e879973..b553d007af 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances( instances, @@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhw NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp index e76052c6e0..7d892855ec 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<3, @@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i ConvFwdDefault, Interwave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<3, NDHWGC, @@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i ConvFwd1x1P0, Interwave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<3, NDHWGC, @@ -52,7 +52,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i ConvFwd1x1S1P0, Interwave, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp index 0593f3f46a..a2d0c6a2e1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<3, @@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i ConvFwdDefault, Intrawave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<3, NDHWGC, @@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i ConvFwd1x1P0, Intrawave, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances(instances, device_grouped_conv_fwd_xdl_bf16_mem_instances<3, NDHWGC, @@ -52,7 +52,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_i ConvFwd1x1S1P0, Intrawave, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 6552f26f88..71f303f3dd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances) + AddClamp>>>& instances) { add_device_operation_instances( instances, @@ -32,7 +32,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh NDHWGK, ConvFwdDefault, Tuple, - AddRelu>{}); + AddClamp>{}); add_device_operation_instances( instances, device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, @@ -42,7 +42,7 @@ void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndh NDHWGK, ConvFwd3x3, Tuple, - AddRelu>{}); + AddClamp>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt deleted file mode 100644 index afdddfec70..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -# ONLY XDL_KERNELS -set(GROUPED_CONV3D_FWD - xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp - - xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - - xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp - - xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp - - xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp -) - -add_instance_library(device_grouped_conv3d_fwd_bias_relu_instance ${GROUPED_CONV3D_FWD}) diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp similarity index 96% rename from profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp rename to profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 9d38263d4e..3ef9f4505d 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -11,7 +11,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -35,19 +35,22 @@ template -bool profile_grouped_conv_fwd_bias_relu_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - const ck::utils::conv::ConvParam& conv_param) +bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) { using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; - using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + using OutElementOp = ck::tensor_operation::element_wise::AddClamp; + + const float floor = 0.f; + const float ceil = 256.f; const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; + const auto out_element_op = OutElementOp{floor, ceil}; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5ea61d2dfc..6692f55b5f 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -251,7 +251,7 @@ add_subdirectory(reduce) add_subdirectory(convnd_fwd) add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_convnd_fwd) -add_subdirectory(grouped_convnd_fwd_bias_relu) +add_subdirectory(grouped_convnd_fwd_bias_clamp) add_subdirectory(grouped_convnd_bwd_weight) add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 3e7296b1eb..fc04af5cdb 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -# Currently ck_tile is only built on gfx94/gfx95 +# Currently ck_tile_gemm is only built on gfx94/gfx95 set(EXAMPLE_GEMM_COMPILE_OPTIONS "") if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) @@ -12,8 +12,6 @@ list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -enable-noalias-to-md-conversion=0 ) -if(CK_USE_OCP_FP8) - list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) @@ -25,4 +23,3 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") else() message("Skipping ck_tile_gemm tests for current target") endif() -endif() diff --git a/test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt b/test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt new file mode 100644 index 0000000000..4630a37d33 --- /dev/null +++ b/test/grouped_convnd_fwd_bias_clamp/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_grouped_convnd_fwd_bias_clamp test_grouped_convnd_fwd_bias_clamp.cpp) + target_link_libraries(test_grouped_convnd_fwd_bias_clamp PRIVATE utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) +endif() diff --git a/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp b/test/grouped_convnd_fwd_bias_clamp/test_grouped_convnd_fwd_bias_clamp.cpp similarity index 88% rename from test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp rename to test/grouped_convnd_fwd_bias_clamp/test_grouped_convnd_fwd_bias_clamp.cpp index c508235d9c..7d5437d247 100644 --- a/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp +++ b/test/grouped_convnd_fwd_bias_clamp/test_grouped_convnd_fwd_bias_clamp.cpp @@ -7,11 +7,11 @@ #include #include -#include "profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp" +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -using AddRelu = ck::tensor_operation::element_wise::AddRelu; +using AddClamp = ck::tensor_operation::element_wise::AddClamp; template class TestGroupedConvndFwd : public ::testing::Test @@ -32,16 +32,16 @@ class TestGroupedConvndFwd : public ::testing::Test bool pass = true; for(auto& param : conv_params) { - pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_relu_impl( + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( true, // do_verification 1, // init_method: integer value false, // do_log diff --git a/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt b/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt deleted file mode 100644 index 680a92b19c..0000000000 --- a/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_grouped_convnd_fwd_bias_relu test_grouped_convnd_fwd_bias_relu.cpp) - target_link_libraries(test_grouped_convnd_fwd_bias_relu PRIVATE utility device_grouped_conv2d_fwd_bias_relu_instance device_grouped_conv3d_fwd_bias_relu_instance) -endif() diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index bc613a931e..01b064ea98 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -1,43 +1,60 @@ - # generate a list of kernels, but not actually emit files at config stage execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${CMAKE_CURRENT_BINARY_DIR} - --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + # --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json --list_blobs - RESULT_VARIABLE ret -) -set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS - ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") + message( FATAL_ERROR "Fail to list kernels via Python. ${ret}") endif() file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) +set(GEMM_CODEGEN_CPP_FILES "") +set(GEMM_CODEGEN_HPP_FILES "") + +foreach(blob ${GEMM_CODEGEN_BLOBS}) + string(STRIP "${blob}" stripped_blob) + + if(stripped_blob MATCHES "\\.cpp$") + list(APPEND GEMM_CODEGEN_CPP_FILES "${stripped_blob}") + elseif(stripped_blob MATCHES "\\.hpp$") + list(APPEND GEMM_CODEGEN_HPP_FILES "${stripped_blob}") + endif() +endforeach() + add_custom_command( OUTPUT ${GEMM_CODEGEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py --working_path ${CMAKE_CURRENT_BINARY_DIR} - --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + # --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json --gen_blobs - DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py - ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt - ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json ) -set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") -message("adding example ${EXECUTABLE_GEMM_INSTANCE}") +add_library(gemm_template_instances OBJECT EXCLUDE_FROM_ALL ${GEMM_CODEGEN_CPP_FILES}) +# Explicitly set LINKER_LANGUAGE to avoid build config failures with Ninja. +set_target_properties(gemm_template_instances PROPERTIES LINKER_LANGUAGE CXX) +target_include_directories(gemm_template_instances PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(gemm_template_instances PRIVATE ${GEMM_CODEGEN_HPP_FILES}) + +set(BENCHMARK_GEMM_EXECUTABLE "benchmark_gemm") +message("adding example ${BENCHMARK_GEMM_EXECUTABLE}") -# use build as include directory include_directories(${CMAKE_CURRENT_BINARY_DIR}) -add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp) -target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS}) + +add_library(gemm_host_api INTERFACE EXCLUDE_FROM_ALL) +target_include_directories(gemm_host_api INTERFACE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(gemm_host_api INTERFACE ${GEMM_CODEGEN_HPP_FILES} gemm_host_api.hpp) +target_link_libraries(gemm_host_api INTERFACE gemm_template_instances) + +add_executable(${BENCHMARK_GEMM_EXECUTABLE} EXCLUDE_FROM_ALL benchmark_gemm.cpp) +target_include_directories(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE benchmark_gemm.hpp gemm_profiler.hpp) +target_link_libraries(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE gemm_host_api) set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS) @@ -46,6 +63,6 @@ list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS -Wno-float-equal --offload-compress) -target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS}) +target_compile_options(${BENCHMARK_GEMM_EXECUTABLE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS}) set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md index f7d86e90fe..db624e576e 100644 --- a/tile_engine/ops/gemm/README.md +++ b/tile_engine/ops/gemm/README.md @@ -4,10 +4,11 @@ CK Tile Engine GEMM is used to generate and run GEMM kernels with different comb # Kernel Configurations -Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes. +User can provide kernel configuration such as tile size, warp size, padding, pipeline, scheduler and epilogue in the config file with limited values. For reference please see `./configs/user_provided_config.json`. -Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels. +The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json` +If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark. ## Build Instructions ``` bash @@ -16,41 +17,45 @@ mkdir build && cd build # build composable kernel sh ../script/cmake-ck-dev.sh ../ # replace with the appropriate architecture (example gfx942) or leave blank # generate the executable -make tile_engine_gemm -j +make benchmark_gemm -j ``` -`tile_engine_gemm` will be located in the `./bin/` directory. +`benchmark_gemm` will be located in the `./bin/` directory. + +`benchmark_gemm` must be rebuilt everytime if configuration file is modified. -_`tile_engine_gemm` must be rebuilt everytime `instance_combination.json` is modified._ ``` bash -rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild +rm -rf tile_engine/ && make benchmark_gemm -j # rebuild ``` -## tile_engine_gemm inputs +## benchmark_gemm inputs ``` + -m The value for m dimension. Default is 3840. + -n The value for n dimension. Default is 4096. + -k The value for k dimension. Default is 2048. + -stride_a The stride value for tensor A. Default is 0. + -stride_b The stride value for tensor B. Default is 0. + -stride_c The stride value for tensor C Default is 0. + -split_k The split value for k dimension. Default is 1. + -v The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 2, validation on GPU. + -log Wether output kernel instance information or not. Possible values are true or false. Default is false. + -warmup The number of iterations before benchmark the kernel. Default is 50. + -repeat The number of iterations to benchmark the kernel. Default is 100. + -timer Whether if the timer is gpu timer or not. Possible values are true or false. Default is true. + -init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random. + -flush_cache To flush cache in between different runs.Possible values are true or false. Default is false. + -rotating_count count to flush cache. Default is 5. + -metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency. + -csv_filename The filename of benchmark result. Default is gemm_kernel. + -structured_sparsity whether use sparsity kernel or not. Possible values are true or false. Default is false. + -pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3. + -epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle. + -pad_m Whether pad or not in m direction. Possible values are true or false. Default is false. + -pad_n Whether pad or not in n direction. Possible values are true or false. Default is false. + -pad_k Whether pad or not in k direction. Possible values are true or false. Default is false. - -m m dimension (default:3840) - -n n dimension (default:4096) - -k k dimension (default:2048) - -stride_a Tensor A stride (default:0) - -stride_b Tensor B stride (default:0) - -stride_c Tensor C stride (default:0) - -split_k SplitK value (default:1) - -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) - -warmup Number of iterations before benchmark the kernel (default:50) - -repeat Number of iterations to benchmark the kernel (default:100) - -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) --structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0) - -pipeline possible values are: compv3, compv4, mem (default:compv3) - -scheduler possible values are: intrawave, interwave (default:intrawave) - -epilogue possible values are: cshuffle, default (default:cshuffle) - -pad_m Pad in m direction - true/false (default:false) - -pad_n Pad in n direction - true/false (default:false) - -pad_k Pad in k direction - true/false (default:false) - -Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json +Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json ``` -Note: In `./configs/instance_combination.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. +Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. ## Example @@ -86,7 +91,7 @@ The following JSON file specifies parameters used to generate and build GEMM ker At runtime, a specific subset of the generated kernels can be selected using command-line arguments. ``` bash -./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default +./bin/benchmark_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default ``` The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. diff --git a/tile_engine/ops/gemm/benchmark_gemm.cpp b/tile_engine/ops/gemm/benchmark_gemm.cpp new file mode 100644 index 0000000000..db2b648437 --- /dev/null +++ b/tile_engine/ops/gemm/benchmark_gemm.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "gemm_profiler.hpp" +#include "benchmark_gemm.hpp" + +void benchmark_gemm(const ck_tile::ArgParser& arg_parser) +{ + GemmProblem gemm_problem{arg_parser.get_int("split_k"), + arg_parser.get_int("m"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("stride_a"), + arg_parser.get_int("stride_b"), + arg_parser.get_int("stride_c"), + DataTypeTraits::name, + DataTypeTraits::name, + DataTypeTraits::name, + DataTypeTraits::name, + ALayout::name, + BLayout::name, + CLayout::name, + arg_parser.get_bool("structured_sparsity")}; + + Setting setting{arg_parser.get_int("warmup"), + arg_parser.get_int("repeat"), + arg_parser.get_bool("timer"), + arg_parser.get_int("verify"), + arg_parser.get_int("init"), + arg_parser.get_bool("log"), + arg_parser.get_str("csv_filename"), + arg_parser.get_bool("flush_cache"), + arg_parser.get_int("rotating_count")}; + + auto& profiler = GemmProfiler::instance(setting); + + try + { + auto kernel_func = get_kernel_func_by_trait(arg_parser); + profiler.benchmark(gemm_problem, kernel_func); + profiler.select_best_instance(static_cast(arg_parser.get_int("metric"))); + } + catch(const std::exception& e) + { + std::cerr << "Benchmark failed: " << e.what() << std::endl; + } +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + benchmark_gemm(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp new file mode 100644 index 0000000000..459a40b080 --- /dev/null +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "gemm_host_api.hpp" + +enum class Metric +{ + LATENCY = 0, + TFLOPS = 1, + BANDWIDTH = 2 +}; + +inline constexpr auto get_metric_name(Metric m) +{ + switch(m) + { + case Metric::LATENCY: return "latency"; + case Metric::TFLOPS: return "tflops"; + case Metric::BANDWIDTH: return "bandwidth"; + default: throw std::invalid_argument("Unsupported metric type"); + } +} + +struct GemmProblem +{ + int split_k_; + int m_, n_, k_; + int stride_a_, stride_b_, stride_c_; + + std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_; + std::string layout_a_, layout_b_, layout_c_; + + bool structured_sparsity_; + + friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem) + { + os << "{\n" + << " \"split_k\":" << problem.split_k_ << ",\n" + << " \"m\":" << problem.m_ << ",\n" + << " \"n\":" << problem.n_ << ",\n" + << " \"k\":" << problem.k_ << ",\n" + << " \"stride_a\":" << problem.stride_a_ << ",\n" + << " \"stride_b\":" << problem.stride_b_ << ",\n" + << " \"stride_c\":" << problem.stride_c_ << ",\n" + << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" + << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" + << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" + << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" + << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" + << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" + << " \"layout_c\":\"" << problem.layout_c_ << "\"\n" + << " \"structured_sparsity\":\"" << problem.structured_sparsity_ << "\"\n" + << "}"; + return os; + } +}; + +struct PerformanceResult +{ + double latency_; + double tflops_; + double bandwidth_; + + static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) + { + switch(m) + { + case Metric::LATENCY: return a.latency_ < b.latency_; + case Metric::TFLOPS: return a.tflops_ > b.tflops_; + case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; + default: throw std::invalid_argument("Unsupported metric type"); + } + } + + friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) + { + os << "{\n" + << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ + << ",\n" + << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" + << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" + << "}"; + return os; + } +}; + +struct KernelInstance +{ + std::string name_; + GemmProblem problem_; + PerformanceResult perf_result_; + + static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) + { + return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); + } + + friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) + { + os << "{\n" + << " \"name\": \"" + << "{\n" + << obj.name_ << "\n}" + << "\",\n" + << " \"problem\": \"" << obj.problem_ << "\",\n" + << " \"perf_result\": " << obj.perf_result_ << "\n" + << "}"; + return os; + } +}; + +struct Setting +{ + int n_warmup_; + int n_repeat_; + bool is_gpu_timer_; + int verify_; + int init_method_; + bool log_; + std::string csv_filename_; + bool flush_cache_; + int rotating_count_; +}; + +inline std::string get_rocm_version() +{ + std::ifstream version_file("/opt/rocm/.info/version"); + if(version_file.is_open()) + { + std::string version; + std::getline(version_file, version); + return version; + } + return "Unknown"; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +/// @brief Function to compare the results of the device and host computations +bool compare(ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; + + return pass; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) +{ + if(verify == 1) + { + c_m_n_host_result.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); + } +} diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py new file mode 100644 index 0000000000..a8955cec91 --- /dev/null +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -0,0 +1,239 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# -*- coding: utf-8 -*- + +""" +Mappings and utility functions for kernel code generation. +""" + +import subprocess +import re +from functools import lru_cache + +DATA_TYPE_MAP = {'fp32': 'float', + 'fp16': 'ck_tile::half_t', + 'bf16': 'ck_tile::bf16_t', + 'int8': 'ck_tile::int8_t', + 'fp8': 'ck_tile::fp8_t', + 'bf8': 'ck_tile::bf8_t', + 'int4': 'ck_tile::pk_int4_t' + } + +LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor', + 'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'} + +DEFAULT_EPILOGUE = """ + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; +""" + +CSHUFFLE_EPILOGUE = """ + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; +""" +HOT_LOOP_FALSE = """ + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); + } +""" +RUN_MEM = """ + // Handle One and Full cases directly + if (tail_num == ck_tile::TailNumber::One) { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } else if (tail_num == ck_tile::TailNumber::Full) { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + auto check_tail = [&](auto... TNs) { + ([&]{ + if constexpr(BaseGemmPipeline::PrefetchStages > static_cast(decltype(TNs)::value)) { + if(tail_num == decltype(TNs)::value) { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + }(), ...); + }; + + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{} + ); +""" + +RUN_COMPV3 = """ + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even."); + } +""" + +RUN_COMPV4 = """ + if(tail_num == ck_tile::TailNumber::Three) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +""" + + +PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], + 'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], + 'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} + +SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave', + 'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'} + +EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE, + 'cshuffle': CSHUFFLE_EPILOGUE} + +HOT_LOOP_TRUE = {'mem': RUN_MEM, + 'compv3': RUN_COMPV3, + 'compv4': RUN_COMPV4} + + +def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)] + + +# To Do: add some more supported combinations +warp_tile_supported_combinations = { + "gfx90a": { + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] + }, + "gfx950": { + 'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + 'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } +} + +# To Do: remove some unsupported combinations +trait_unsupported_combinations = { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave") +} + + +def element_size(data_type: str) -> float: + """Calculate the size (in bytes) of a single element for given data type.""" + data_type = data_type.lower() + if data_type in {'fp16', 'bf16'}: + return 2 + elif data_type in {'int8', 'fp8', 'bf8'}: + return 1 + elif data_type == 'int4': + return 0.5 + else: + raise ValueError(f"Unsupported data type: {data_type}") + + +GPU_NAME_PATTERN = re.compile(r'Name:\s*(gfx\d+\w*)') + + +@lru_cache(maxsize=1) +def get_gpu_name_by_id(gpu_id: int = 0) -> str: + """Retrieve GPU name (e.g. gfx90a) by device ID""" + try: + output = subprocess.check_output( + ["rocminfo"], + text=True, + stderr=subprocess.PIPE, + timeout=5 + ) + if matches := GPU_NAME_PATTERN.finditer(output): + gpu_list = [m.group(1) for m in matches] + return gpu_list[gpu_id] if gpu_id < len(gpu_list) else "" + + return "" + + except subprocess.CalledProcessError as e: + print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}") + except FileNotFoundError: + print("ROCm tools not installed (requires rocminfo)") + except subprocess.TimeoutExpired: + print("GPU query timeout (5s)") + except Exception as e: + print(f"GPU detection error: {str(e)}") + + return "" diff --git a/tile_engine/ops/gemm/configs/default_config.json b/tile_engine/ops/gemm/configs/default_config.json new file mode 100644 index 0000000000..09fe3b83ac --- /dev/null +++ b/tile_engine/ops/gemm/configs/default_config.json @@ -0,0 +1,130 @@ +{ + "problem": { + "layout_a": { + "values": [ + "r" + ] + }, + "layout_b": { + "values": [ + "c" + ] + }, + "layout_c": { + "values": [ + "r" + ] + }, + "datatype_a": { + "values": [ + "fp16" + ] + }, + "datatype_b": { + "values": [ + "fp16" + ] + }, + "datatype_c": { + "values": [ + "fp16" + ] + } + }, + "tile_config": { + "tile_m": { + "max": 512, + "min": 64, + "step": 64, + "exclude": [] + }, + "tile_n": { + "max": 512, + "min": 64, + "step": 32, + "exclude": [] + }, + "tile_k": { + "max": 512, + "min": 64, + "step": 64, + "exclude": [] + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16, + 32 + ] + }, + "warp_tile_n": { + "values": [ + 16, + 32 + ] + }, + "warp_tile_k": { + "values": [ + 8, + 16, + 32, + 64, + 128 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv4", + "compv3", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "default", + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json deleted file mode 100644 index b497513efa..0000000000 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ /dev/null @@ -1,62 +0,0 @@ -{ - "architecture": { - "values": ["gfx90a"] - }, - "layout_a": { - "values": ["r"] - }, - "layout_b": { - "values": ["c"] - }, - "layout_c": { - "values": ["r"] - }, - "datatype": { - "values": ["fp16"] - }, - "tile_m": { - "values": [256] - }, - "tile_n": { - "values": [256] - }, - "tile_k": { - "values": [32] - }, - "warp_m": { - "values": [2] - }, - "warp_n": { - "values": [2] - }, - "warp_k": { - "values": [1] - }, - "warp_tile_m": { - "values": [32] - }, - "warp_tile_n": { - "values": [32] - }, - "warp_tile_k": { - "values": [16] - }, - "kPadM": { - "values": [false] - }, - "kPadN": { - "values": [false] - }, - "kPadK": { - "values": [false] - }, - "pipeline": { - "values": ["compv3", "compv4", "mem"] - }, - "scheduler": { - "values": ["intrawave", "interwave"] - }, - "epilogue": { - "values": ["default", "cshuffle"] - } -} diff --git a/tile_engine/ops/gemm/configs/user_provided_config.json b/tile_engine/ops/gemm/configs/user_provided_config.json new file mode 100644 index 0000000000..6a6e726e40 --- /dev/null +++ b/tile_engine/ops/gemm/configs/user_provided_config.json @@ -0,0 +1,116 @@ +{ + "problem": { + "layout_a": { + "values": [ + "r" + ] + }, + "layout_b": { + "values": [ + "c" + ] + }, + "layout_c": { + "values": [ + "r" + ] + }, + "datatype_a": { + "values": [ + "fp16" + ] + }, + "datatype_b": { + "values": [ + "fp16" + ] + }, + "datatype_c": { + "values": [ + "fp16" + ] + } + }, + "tile_config": { + "tile_m": { + "values": [ + 128 + ] + }, + "tile_n": { + "values": [ + 128 + ] + }, + "tile_k": { + "values": [ + 32 + ] + }, + "warp_m": { + "values": [ + 2 + ] + }, + "warp_n": { + "values": [ + 2 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 32 + ] + }, + "warp_tile_n": { + "values": [ + 32 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "default", + "cshuffle" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + } + } +} \ No newline at end of file diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp deleted file mode 100755 index a5447cd658..0000000000 --- a/tile_engine/ops/gemm/gemm_host_api.cpp +++ /dev/null @@ -1,192 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck_tile/host.hpp" -#include "gemm_common.hpp" -#include "gemm_dispatcher.hpp" -#include "gemm_host_api.hpp" - -void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, - bool structured_sparsity, - KernelTraits& trait, - ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& stream) -{ - return GemmDispatcher::dispatch(c_m_n_dev_buf, - c_m_n_host_result, - c_m_n_dev_result, - verify, - structured_sparsity, - trait, - args, - stream); -} - -template -void run(const ck_tile::ArgParser& arg_parser) -{ - const ALayout a_layout = ALayout{}; - const BLayout b_layout = BLayout{}; - - ck_tile::index_t kbatch = arg_parser.get_int("split_k"); - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - int n_warmup = arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - int verify = arg_parser.get_int("v"); - ck_tile::index_t init_method = arg_parser.get_int("init"); - bool structured_sparsity = arg_parser.get_bool("structured_sparsity"); - - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); - - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_n_dev_result( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - - if(init_method == 0) - { - ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); - } - else if(init_method == 1) - { - ck_tile::FillMonotonicSeq{}(a_m_k); - ck_tile::FillMonotonicSeq{}(b_k_n); - } - else if(init_method == 2) - { - ck_tile::FillConstant{static_cast(1)}(a_m_k); - ck_tile::FillConstant{static_cast(1)}(b_k_n); - } - else - { - a_m_k.SetZero(); - b_k_n.SetZero(); - } - - if(structured_sparsity) - { - ck_tile::AdjustToStructuredSparsity{}(a_m_k); - } - - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - - if constexpr(std::is_same_v) - { - // Permute vector pk_i4x4 data for device implementation - ck_tile::HostTensor b_k_n_dev = b_k_n; - // permute_tensor_b(b_k_n_dev); - permute_vectors_i4x4_b(b_k_n_dev); - b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); - } - else - { - b_k_n_dev_buf.ToDevice(b_k_n.data()); - } - - a_m_k_dev_buf.ToDevice(a_m_k.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - - ck_tile::GemmHostArgs gemm_args; - gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - gemm_args.k_batch = kbatch; - gemm_args.M = M; - gemm_args.N = N; - gemm_args.K = K; - gemm_args.stride_A = stride_A; - gemm_args.stride_B = stride_B; - gemm_args.stride_C = stride_C; - - KernelTraits trait; - trait.pipeline = arg_parser.get_str("pipeline"); - trait.scheduler = arg_parser.get_str("scheduler"); - trait.epilogue = arg_parser.get_str("epilogue"); - trait.kPadM = arg_parser.get_bool("pad_m"); - trait.kPadN = arg_parser.get_bool("pad_n"); - trait.kPadK = arg_parser.get_bool("pad_k"); - - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name << std::endl; - - ck_tile::HostTensor c_m_n_host_result( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - - if(verify) - { - gemm_host_reference(verify, - a_m_k, - b_k_n, - c_m_n_host_result, - a_m_k_dev_buf, - b_k_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C); - } - - gemm_kernel_launch(c_m_n_dev_buf, - c_m_n_host_result, - c_m_n_dev_result, - verify, - structured_sparsity, - trait, - gemm_args, - ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - - return; -} - -int main(int argc, char* argv[]) -{ - try - { - auto [result, parser] = create_args(argc, argv); - if(!result) - return EXIT_FAILURE; - run(parser); - return 0; - } - catch(const std::exception& e) - { - std::cerr << "Error: " << e.what() << "\n"; - return EXIT_FAILURE; - } -} diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp old mode 100755 new mode 100644 index 579d2770db..b3aab6ad92 --- a/tile_engine/ops/gemm/gemm_host_api.hpp +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -1,16 +1,15 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include +#pragma once #include -#include -#include #include #include -#include "ck_tile/ops/gemm.hpp" -#pragma once +#include "ck_tile/host.hpp" +#include "gemm_dispatcher.hpp" +#include "gemm_common.hpp" template struct DataTypeTraits; @@ -57,24 +56,6 @@ struct DataTypeTraits static constexpr const char* name = "pk_int4_t"; }; -/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a -/// specific kernel instance based on the provided settings. -struct KernelTraits -{ - /// @brief The name of the pipeline. - std::string pipeline; - /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). - std::string scheduler; - /// @brief The name of the epilogue (e.g., "cshuffle", "default"). - std::string epilogue; - /// @brief Indicates whether padding is applied to the M dimension. - bool kPadM; - /// @brief Indicates whether padding is applied to the N dimension. - bool kPadN; - /// @brief Indicates whether padding is applied to the K dimension. - bool kPadK; -}; - template static constexpr inline auto is_row_major(Layout layout_) { @@ -82,49 +63,76 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } -template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); - // Calculate error due to split_k accumulation - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); -} - inline auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("split_k", "1", "splitK value") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("structured_sparsity", "0", "0:false, 1:true") - .insert("pipeline", "compv3", "compv3, compv4, mem") - .insert("scheduler", "intrawave", "intrawave, interwave") - .insert("epilogue", "cshuffle", "cshuffle, default") - .insert("pad_m", "false", "true, false") - .insert("pad_n", "false", "true, false") - .insert("pad_k", "false", "true, false"); + arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.") + .insert("n", "4096", "The value for n dimension. Default is 4096.") + .insert("k", "2048", "The value for k dimension. Default is 2048.") + .insert("stride_a", "0", "The stride value for tensor A. Default is 0.") + .insert("stride_b", "0", "The stride value for tensor B. Default is 0.") + .insert("stride_c", "0", "The stride value for tensor C Default is 0.") + .insert("split_k", "1", "The split value for k dimension. Default is 1.") + .insert("verify", + "2", + "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 " + "for validation on GPU. Default is 2, validation on GPU.") + .insert("log", + "false", + "Wether output kernel instance information or not. Possible values are true or " + "false. Default is false") + .insert( + "warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.") + .insert( + "repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.") + .insert("timer", + "true", + "Whether if the timer is gpu timer or not. Possible values are false or true. " + "Default is true.") + .insert("init", + "0", + "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 " + "for constant(1). Default is 0, random.") + .insert("flush_cache", + "false", + "To flush cache, possible values are true or false. " + "Default is false.") + .insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.") + .insert("metric", + "0", + "Metric with which to measure kernel performance. Set to 0 for latency, 1 for " + "tflops, or 2 for bandwidth. Default is 0, latency.") + .insert("csv_filename", + "gemm_kernel", + "The filename of benchmark result. Default is gemm_kernel.") + .insert("structured_sparsity", + "false", + "Whether use sparsity kernel or not. Possible values are true or false. Default is " + "false") + .insert( + "pipeline", + "compv3", + "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.") + .insert("scheduler", + "intrawave", + "The type of pipeline. Possible values are compv3, compv4 or mem. Default is " + "compv3.") + .insert( + "epilogue", + "cshuffle", + "The type of epilogue. Possible values are cshuffle or default. Default is csshuffle.") + .insert("pad_m", + "false", + "Whether pad or not in m direction. Possible values are true or false. Default is " + "false.") + .insert("pad_n", + "false", + "Whether pad or not in n direction. Possible values are true or false. Default is " + "false.") + .insert("pad_k", + "false", + "Whether pad or not in k direction. Possible values are true or false. Default is " + "false."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -185,79 +193,17 @@ void permute_vectors_i4x4_b(Tensor& tensor) } } -/// @brief Function to compare the results of the device and host computations -void compare(ck_tile::index_t K, - ck_tile::index_t kbatch, - ck_tile::HostTensor& c_m_n_dev_result, - ck_tile::HostTensor& c_m_n_host_result) +auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser) { - const float max_accumulated_value = - *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - bool pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_result, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); + KernelTraits trait; + trait.pipeline = arg_parser.get_str("pipeline"); + trait.scheduler = arg_parser.get_str("scheduler"); + trait.epilogue = arg_parser.get_str("epilogue"); + trait.pad_m = arg_parser.get_bool("pad_m"); + trait.pad_n = arg_parser.get_bool("pad_n"); + trait.pad_k = arg_parser.get_bool("pad_k"); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; - std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; -} - -/// @brief Function to get the kernel output with reference implementation on CPU/GPU -template -void gemm_host_reference(int verify, - ck_tile::HostTensor& a_m_k, - ck_tile::HostTensor& b_k_n, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C) -{ - if(verify == 1) - { - c_m_n_host_result.SetZero(); - - ck_tile::reference_gemm( - a_m_k, b_k_n, c_m_n_host_result); - } - else if(verify == 2) - { - if constexpr(std::is_same_v) - { - // Restore input for B for gpu reference - b_k_n_dev_buf.ToDevice(b_k_n.data()); - } - - ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); - c_m_n_host_result.SetZero(); - c_m_n_gpu_buf_ref.SetZero(); - - ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); - BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); - CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); - - ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); - - c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); - } + bool structured_sparsity = arg_parser.get_bool("structured_sparsity"); + + return GemmDispatcher::dispatch(structured_sparsity, trait); } diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index dd8b4d1157..ea7fa4e67c 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -1,385 +1,199 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -# generate kernel instances to speed up compilation + +# -*- coding: utf-8 -*- + +""" +generate kernel instances to speed up compilation +""" import argparse -from enum import IntEnum -from pathlib import Path -import sys -from typing import List, Optional, Dict, Any -import functools import itertools -import copy -import json -from dataclasses import dataclass - -DATA_TYPE_MAP = {'fp32' : 'float', - 'fp16' : 'ck_tile::half_t', - 'bf16' : 'ck_tile::bf16_t', - 'int8' : 'ck_tile::int8_t', - 'fp8' : 'ck_tile::fp8_t', - 'bf8' : 'ck_tile::bf8_t', - 'int4' : 'ck_tile::pk_int4_t' - } +from pathlib import Path +from typing import List, Optional +from json_config import GemmConfig, RangeConfigParam +from codegen_utils import ( + DATA_TYPE_MAP, + LAYOUT_MAP, + DEFAULT_EPILOGUE, + CSHUFFLE_EPILOGUE, + HOT_LOOP_FALSE, + RUN_MEM, + RUN_COMPV3, + RUN_COMPV4, + PIPELINE_MAP, + SCHEDULER_MAP, + EPILOGUE_MAP, + HOT_LOOP_TRUE, + BOOL_MAP, + warp_tile_supported_combinations, + trait_unsupported_combinations, + element_size, + get_gpu_name_by_id +) +import logging +import time -LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', - 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} - - -warp_tile_combinations_map = { - "gfx90a": { - 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp8': [[32, 32, 16], [32, 32, 32]], - 'bf8': [[32, 32, 16], [32, 32, 32]] - }, - "gfx942": { - 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], - 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] - }, - "gfx950": { - 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], - 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], - 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] - } - } - -def sizeOf(data_type): - if data_type == 'fp16' or data_type == 'bf16': - return 2 - elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8': - return 1 - elif data_type == 'int4': ## TODO:: needs to confirm - return 0.5 - else: - return 4 - -DEFAULT_EPILOGUE = """ - using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; -""" - -CSHUFFLE_EPILOGUE = """ - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; -""" -HOT_LOOP_FALSE = """ - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); - } -""" -RUN_MEM = """ - // Handle One and Full cases directly - if (tail_num == ck_tile::TailNumber::One) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } else if (tail_num == ck_tile::TailNumber::Full) { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - // Variadic call using fold expression - auto check_tail = [&](auto... TNs) { - (try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...); - }; - - check_tail( - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{}, - ck_tile::integral_constant{} - ); -""" - -RUN_COMPV3 = """ - if(tail_num == ck_tile::TailNumber::Full) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even."); - } -""" - -RUN_COMPV4 = """ - if(tail_num == ck_tile::TailNumber::Three) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } -""" - - -PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], - 'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], - 'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} - -SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave', - 'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'} - -EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE, - 'cshuffle' : CSHUFFLE_EPILOGUE} - -HOT_LOOP_TRUE = {'mem' : RUN_MEM, - 'compv3' : RUN_COMPV3, - 'compv4' : RUN_COMPV4} - - -def BOOL_MAP(b_) -> str: - if b_: - return 'true' - else: - return 'false' - -@dataclass -class GemmConfig: - def __init__(self, config_data): - self.matrix_cfg : Dict[str, Any] = {} - self.impl_cfg : Dict[str, Any] = {} - for key, value in config_data.items(): - if key in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: - self.matrix_cfg[key] = value - else: - self.impl_cfg[key] = value - - @property - def architecture(self) -> str: - return self.matrix_cfg["architecture"]["values"][0] - - @property - def datatype(self) -> str: - return self.matrix_cfg["datatype"]["values"][0] - - @property - def layouts(self) -> List[str]: - return [ - self.matrix_cfg["layout_a"]["values"][0], - self.matrix_cfg["layout_b"]["values"][0], - self.matrix_cfg["layout_c"]["values"][0] - ] +logging.basicConfig(level=logging.INFO) class GemmCodeGenerator: - def __init__(self, output_dir: str, config: GemmConfig): + """GEMM (General Matrix Multiplication) code generator.""" + + def __init__(self, output_dir: str, + user_provided_config: Optional[GemmConfig] = None): self.output_dir = Path(output_dir) - if not self.output_dir.exists(): - self.output_dir.mkdir() + self.output_dir.mkdir(parents=True, exist_ok=True) - self.config = config - self.all_kernels = [] - self.unique_configs = [] - # Validate configurations - self._validate_config() + if user_provided_config is not None: + self.config = user_provided_config + else: + config_path = Path(__file__).resolve().parent / \ + "configs" / "default_config.json" + self.config = GemmConfig.from_json(config_path) - def _validate_config(self): - """Validate matrix and implementation configurations""" - # Matrix config validation - for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: - if len(self.config.matrix_cfg[param]["values"]) != 1: - raise ValueError(f"Matrix config {param} must have exactly one value") - - # Implementation traits validation - required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k", - "warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline", - "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"] - for param in required_params: - if not self.config.impl_cfg.get(param, {}).get("values"): - raise ValueError(f"Missing implementation parameter: {param}") + self.valid_trait_names: List[str] = [] + self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {} - def list_all(self): - """List all possible kernel configurations""" + def list_all_trait_names(self): + """List all possible kernel trait names into file.""" w_p = Path(self.output_dir) - list_p = w_p / 'gemm_instance_blobs.txt' - self._list_config_groups() - with list_p.open('w') as list_f: - list_f.write(str(w_p / ("gemm_common.hpp")) + "\n") - list_f.write(str(w_p / ("gemm_instances.hpp")) + "\n") - list_f.write(str(w_p / ("gemm_dispatcher.hpp")) + "\n") - for group in self.all_kernels: - list_f.write(str(w_p / ("gemm_" + group + ".hpp")) + "\n") - + file_path = w_p / 'gemm_instance_blobs.txt' + self._generate_all_traits() + self._get_valid_trait_tile_combinations() + # Write all file paths to the header file + with file_path.open('w') as f: + f.write(str(w_p / "gemm_common.hpp") + "\n") + f.write(str(w_p / "gemm_instances.hpp") + "\n") + f.write(str(w_p / "gemm_dispatcher.hpp") + "\n") + for trait in self.valid_trait_names: + f.write(str(w_p / f"gemm_{trait}.hpp") + "\n") + for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + for tile in tile_valid_params: + for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile: + sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_b'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_c'] == 'fp16' and \ + ((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or + (warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32)) + if sparse: + f.write(str( + w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp") + "\n") + f.write(str( + w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp") + "\n") - def _list_config_groups(self): + def _generate_all_traits(self): + """Generate all possible kernel traits names.""" params = [ - ("pipeline", "pipeline"), - ("epilogue", "epilogue"), - ("scheduler", "scheduler"), - ("kPadM", "kPadM"), - ("kPadN", "kPadN"), - ("kPadK", "kPadK") - ] - + "pipeline", + "epilogue", + "scheduler", + "pad_m", + "pad_n", + "pad_k"] + # Generate all unique_combinations - _unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params])) + _unique = set(itertools.product(*[ + getattr(self.config.trait_config, param).values + for param in params + ])) + for combo in _unique: - config = {name: value for (_, name), value in zip(params, combo)} - pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values() - # To remove some unsupported combinations - unsupported_combination = [("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave")] - if (pipeline, epilogue, scheduler) not in unsupported_combination: - group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" - self.all_kernels.append(group_name) - self.unique_configs.append(config) + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo + current_combination = (pipeline, epilogue, scheduler) - def generate_all(self): - self._generate_common_header() - self._generate_config_groups() - self._generate_dispatcher() - + if current_combination not in trait_unsupported_combinations: + trait_name = ( + f"{pipeline}_{epilogue}_{scheduler}_" + f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}" + ) + self.valid_trait_names.append(trait_name) + else: + logging.debug( + f"Invalid combination: {pipeline}-{epilogue}-{scheduler}" + ) - def _generate_common_header(self): - """Generate common header with datatypes and layout""" - self.ctype = self.config.datatype - self.atype = self.config.datatype - self.btype = self.config.datatype - if self.config.datatype in ['fp8', 'bf8']: - self.ctype = 'fp16' - elif self.config.datatype in ['int4']: - self.atype = 'fp16' - self.ctype = 'fp16' + def generate_all_instance_files(self): + """Generate all kernel instances files.""" + self._generate_common_header_file() + self._generate_all_trait_files() + self._generate_dispatcher_file() + + def _generate_common_header_file(self): + """Generate common header file with datatypes and layout.""" content = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once + #include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" // Data types -using ADataType = {DATA_TYPE_MAP[self.atype]}; -using BDataType = {DATA_TYPE_MAP[self.btype]}; +using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]}; +using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]}; using AccDataType = float; -using CDataType = {DATA_TYPE_MAP[self.ctype]}; +using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]}; // Layout configurations -using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; -using BLayout = {LAYOUT_MAP[self.config.layouts[1]]}; -using CLayout = {LAYOUT_MAP[self.config.layouts[2]]}; +using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]}; +using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]}; +using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]}; """ - (self.output_dir / "gemm_common.hpp").write_text(content) - def _generate_config_groups(self): - """Generate implementation configuration groups""" - if not self.unique_configs: # Check if the list is empty - self._list_config_groups() - for config in self.unique_configs: - self._generate_config_group(**config) - self.generate_common_instances_header() + def _generate_all_trait_files(self): + """Generate all kernel traits into files.""" + if not self.valid_trait_names: + self._generate_all_traits() + self._get_valid_trait_tile_combinations() + for trait in self.valid_trait_names: + self._generate_trait_file(trait) + self._generate_instantiation_source_files() + self._generate_common_instance_header_file() - - def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str, - kPadM: bool, kPadN: bool, kPadK: bool): - """Generate a configuration group with all tile/warp combinations""" - group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" - filename = f"gemm_{group_name}.hpp" + def _generate_trait_file(self, trait: str): + """Generate a trait with all tile/warp combinations.""" + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_") + filename = f"gemm_{trait}.hpp" content = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + #include "gemm_common.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/host.hpp" -namespace {group_name} {{ +namespace {trait} {{ """ # Add template struct with configuration - content += self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK) + content += self._generate_kernel_struct( + pipeline, epilogue, scheduler, pad_m, pad_n, pad_k) - content += f"\n}} // namespace {group_name}\n" + content += f"\n}} // namespace {trait}\n" (self.output_dir / filename).write_text(content) def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str, - kPadM: bool, kPadN: bool, kPadK: bool) -> str: - """Generate kernel struct template""" + pad_m: str, pad_n: str, pad_k: str) -> str: + """Generate the code block of kernel struct""" return f""" -template -void try_run(ck_tile::TailNumber tn) {{ - if constexpr (Pipeline::PrefetchStages > static_cast(TN) - 1) {{ - if (tn == TN) {{ - RunSplitk(ck_tile::bool_constant{{}}, - ck_tile::integral_constant{{}}); - }} - }} -}} + template struct GemmKernel {{ - static constexpr bool kPadM = {BOOL_MAP(kPadM)}; - static constexpr bool kPadN = {BOOL_MAP(kPadN)}; - static constexpr bool kPadK = {BOOL_MAP(kPadK)}; - - static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ + static constexpr bool kPadM = {pad_m}; + static constexpr bool kPadN = {pad_n}; + static constexpr bool kPadK = {pad_k}; + + static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ static constexpr bool permuteA = false; static constexpr bool permuteB = false; static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; @@ -389,7 +203,7 @@ struct GemmKernel {{ static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; - using GemmShape = + using GemmShape = ck_tile::TileGemmShape, ck_tile::sequence, ck_tile::sequence, @@ -403,22 +217,22 @@ struct GemmKernel {{ TileParitionerM01>; using Traits = - ck_tile::TileGemmTraits; + ck_tile::TileGemmTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + ALayout, BLayout, CLayout, TransposeC, structured_sparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; const ck_tile::index_t k_grain = args.k_batch * TileK; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{{0}}; @@ -428,7 +242,7 @@ struct GemmKernel {{ constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; + using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; {EPILOGUE_MAP[epilogue]} using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -451,7 +265,7 @@ struct GemmKernel {{ throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); }} - if(s.log_level_ > 0) + if(stream.log_level_ > 0) {{ std::cout << "Launching kernel with args:" << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" @@ -459,9 +273,52 @@ struct GemmKernel {{ << std::endl; }} - ave_time = ck_tile::launch_kernel(s, + if(stream.flush_cache_) + {{ + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; + + auto is_row_major = [](auto layout_) {{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{{}}; + }}; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{{}}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{{}}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() {{ + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.c_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_)); + }}; + ave_time = ck_tile::launch_kernel_preprocess( + stream, + run_flush_cache, + ck_tile::make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + }} + else{{ + ave_time = ck_tile::launch_kernel(stream, ck_tile::make_kernel( Kernel{{}}, grids, blocks, 0, kargs)); + }} return ave_time; }}; @@ -488,206 +345,333 @@ struct GemmKernel {{ return ave_time; }} - + static std::string get_name() {{ - return std::string("GemmKernel bool: - """Check if the tile configuration is valid for the given group""" - # Extract tile parameters + def is_tile_valid(self, tile: tuple, trait: str) -> bool: + """Check if the tile configuration is valid for the given trait.""" tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile + pipeline, *_ = trait.split("_") - # Extract the pipeline and epilogue from the group name - _, pipeline, epilogue, scheduler, *_ = group.split("_") + # Parameter validity check + invalid_params = [] + if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]: + invalid_params.append( + f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})") + if (warp_m * warp_tile_m) == 0: + invalid_params.append( + f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})") + if (warp_n * warp_tile_n) == 0: + invalid_params.append( + f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})") + if (warp_k * warp_tile_k) == 0: + invalid_params.append( + f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})") - if tile_m % (warp_m * warp_tile_m) == 0 and \ - tile_n % (warp_n * warp_tile_n) == 0 and \ - tile_k % (warp_k * warp_tile_k) == 0: - total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k ) * sizeOf(self.config.datatype) - # Validate and append valid tile parameters - is_compv4 = pipeline == "compv4" - max_tile_size = pow(2, 16) if is_compv4 else pow(2, 15) + if invalid_params: + logging.debug( + f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. " + f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), " + f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})" + ) + return False - if total_tile_in_lds > max_tile_size: - raise ValueError(f'Total tile size should not exceed {max_tile_size / 1024}KB of LDS. ' - f'{tile_m} * {tile_n} * {tile_k} > {max_tile_size / 1024}KB') - arch = self.config.architecture - if [warp_tile_m, warp_tile_n, warp_tile_k] in warp_tile_combinations_map[arch][self.config.datatype]: - return True - return False + # Dimension alignment check + alignment_issues = [] + if tile_m % (warp_m * warp_tile_m) != 0: + alignment_issues.append( + f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}") + if tile_n % (warp_n * warp_tile_n) != 0: + alignment_issues.append( + f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}") + if tile_k % (warp_k * warp_tile_k) != 0: + alignment_issues.append( + f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}") - def _generate_dispatcher(self): - """Generate dispatch mechanism""" - content = """// SPDX-License-Identifier: MIT + if alignment_issues: + logging.debug( + f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. " + f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + ) + return False + + # LDS capacity verification + matrix_a_size = (tile_m * tile_k) * \ + pow(2, element_size(self.config.problem.datatype_map['matrix_a'])) + matrix_b_size = (tile_n * tile_k) * \ + pow(2, element_size(self.config.problem.datatype_map['matrix_b'])) + total_tile_in_lds = matrix_a_size + matrix_b_size + + max_tile_size = 2**16 if pipeline == "compv4" else 2**15 + if total_tile_in_lds > max_tile_size: + logging.debug( + f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > " + f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n" + f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n" + f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B" + ) + return False + + # Warp combination validation + warp_tile_key = f"{self.config.problem.datatype_map['matrix_a']}_{self.config.problem.datatype_map['matrix_b']}_{self.config.problem.datatype_map['matrix_c']}" + current_combination = [warp_tile_m, warp_tile_n, warp_tile_k] + + gpu_name = get_gpu_name_by_id(0) + gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {}) + if not gpu_warp_tile_key: + logging.debug( + f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.") + return False + + allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, []) + if not allowed_combinations: + logging.debug( + f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.") + return False + + if current_combination not in allowed_combinations: + logging.debug( + f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. " + f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}" + ) + return False + + return True + + def _get_valid_trait_tile_combinations(self): + def get_tile_value(tile_param): return tile_param.generate_candidates( + ) if isinstance(tile_param, RangeConfigParam) else tile_param.values + + tile_group = list(itertools.product( + get_tile_value(self.config.tile_config.tile_m), + get_tile_value(self.config.tile_config.tile_n), + get_tile_value(self.config.tile_config.tile_k) + )) + + warp_group = list(itertools.product( + get_tile_value(self.config.tile_config.warp_m), + get_tile_value(self.config.tile_config.warp_n), + get_tile_value(self.config.tile_config.warp_k) + )) + + warp_tile_group = list(itertools.product( + get_tile_value(self.config.tile_config.warp_tile_m), + get_tile_value(self.config.tile_config.warp_tile_n), + get_tile_value(self.config.tile_config.warp_tile_k) + )) + + tile_params = { + t + w + wt + for t in tile_group + for w in warp_group + for wt in warp_tile_group + } + + for trait in self.valid_trait_names: + tile_valid_params = list( + filter(lambda t: self.is_tile_valid(t, trait), tile_params)) + if trait not in self.valid_trait_tile_combinations: + self.valid_trait_tile_combinations[trait] = [] + self.valid_trait_tile_combinations[trait].append(tile_valid_params) + + def _generate_instantiation_source_files(self): + """Generate kernel instance instantiation source files """ + for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + for tile in tile_valid_params: + for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile: + content = f""" +// SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "gemm_common.hpp" -#include "gemm_instances.hpp" -#include "gemm_host_api.hpp" + + +#include "gemm_{trait}.hpp" + +""" + sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_b'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_c'] == 'fp16' and \ + ((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or + (warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32)) + if sparse: + sparse_content = content + f""" +template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>; +""" + (self.output_dir / + f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp").write_text(sparse_content) + + no_sparse_content = content + f""" +template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>; +""" + (self.output_dir / + f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp").write_text(no_sparse_content) + + def _generate_dispatcher_file(self): + """Generate the code block of dispatch mechanism.""" + content = """ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + #include #include #include +#include "gemm_common.hpp" +#include "gemm_instances.hpp" + +/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a +/// specific kernel instance based on the provided settings. +struct KernelTraits +{ + /// @brief The name of the pipeline. + std::string pipeline; + /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). + std::string scheduler; + /// @brief The name of the epilogue (e.g., "cshuffle", "default"). + std::string epilogue; + /// @brief Indicates whether padding is applied to the M dimension. + bool pad_m; + /// @brief Indicates whether padding is applied to the N dimension. + bool pad_n; + /// @brief Indicates whether padding is applied to the K dimension. + bool pad_k; +}; + struct GemmDispatcher { static auto& get_kernel_map() { // Use a static local variable - static std::unordered_map& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map; + static std::unordered_map< + std::string, + std::vector(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>> + kernel_map; return kernel_map; } static void init(bool structured_sparsity) { - auto& kernel_map = get_kernel_map(); - if(!kernel_map.empty()) return; + auto& kernel_map = get_kernel_map(); + if(!kernel_map.empty()) return; \n""" - # Add tile/warp instantiations - tile_params = set(itertools.product( - self.config.impl_cfg["tile_m"]["values"], - self.config.impl_cfg["tile_n"]["values"], - self.config.impl_cfg["tile_k"]["values"], - self.config.impl_cfg["warp_m"]["values"], - self.config.impl_cfg["warp_n"]["values"], - self.config.impl_cfg["warp_k"]["values"], - self.config.impl_cfg["warp_tile_m"]["values"], - self.config.impl_cfg["warp_tile_n"]["values"], - self.config.impl_cfg["warp_tile_k"]["values"] - )) - - for group in self.all_kernels: - content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& stream) {{ - if(structured_sparsity){{ // SMFMA""" - for tile in tile_params: - if self.is_tile_valid(tile, group): - sparse = self.atype == 'fp16' and \ - ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or - (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + for trait, tile_valid_params in self.valid_trait_tile_combinations.items(): + content += f""" kernel_map["{trait}"] = {{""" + for _, tile in enumerate(tile_valid_params): + for j in range(len(tile)): + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[ + j] + content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ + content += f""" + if(structured_sparsity){{ // SMFMA""" + sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_b'] == 'fp16' and \ + self.config.problem.datatype_map['matrix_c'] == 'fp16' and \ + ((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or + (warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32)) content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" - else: - raise ValueError(f"Invalid tile configuration for group {group}: {tile}") - content += f""" - }} else {{""" - for tile in tile_params: - if self.is_tile_valid(tile, group): + return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(sparse)}>>(args, stream);""" content += f""" - run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" - else: - raise ValueError(f"Invalid tile configuration for group {group}: {tile}") + }} else {{""" + content += f""" + return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(False)}>>(args, stream);""" + content += f""" + }} """ + + if j == len(tile)-1: + content += f""" + }} """ + else: + content += f""" + }}, """ content += f""" - }} - }};\n""" + }};\n """ content += """ } - + template - static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) + static std::tuple run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { + std::string name = Kernel::get_name(); float avg_time = Kernel::launch(args, stream); - std::string description = Kernel::get_name(); - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - std::size_t flop = std::size_t(2) * args.M * args.N * args.K; - std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N; - float tflops = static_cast(flop) / 1.E9 / avg_time; - float gb_per_sec = num_byte / 1.E6 / avg_time; - - std::cout << "Performance for " << description << " : " << avg_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - - if(verify) - compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); + return std::make_tuple(name, avg_time); } - - static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::HostTensor& c_m_n_host_result, - ck_tile::HostTensor& c_m_n_dev_result, - int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, - const ck_tile::stream_config& stream) { + + + static auto dispatch(bool structured_sparsity, const KernelTraits& trait) { init(structured_sparsity); const std::string key = assemble_key(trait); - auto& kernel_map = get_kernel_map(); - if(auto it = kernel_map.find(key); it != kernel_map.end()) { - return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream); + auto& kernel_map = get_kernel_map(); + if(auto it = kernel_map.find(key); it != kernel_map.end()) + { + return it->second; } throw std::runtime_error("No suitable kernel found: " + key); } private: static std::string assemble_key(const KernelTraits &trait) { - return std::string(trait.pipeline) + "_" + - trait.epilogue + "_" + + return std::string(trait.pipeline) + "_" + + trait.epilogue + "_" + trait.scheduler + "_" + - "pad_" + - (trait.kPadM ? "true" : "false") + "_" + - (trait.kPadN ? "true" : "false") + "_" + - (trait.kPadK ? "true" : "false"); + (trait.pad_m ? "true" : "false") + "_" + + (trait.pad_n ? "true" : "false") + "_" + + (trait.pad_k ? "true" : "false"); } }; """ (self.output_dir / "gemm_dispatcher.hpp").write_text(content) - -def do_list_blobs(args, gemm_config): - generator = GemmCodeGenerator(args.working_path, gemm_config) - generator.list_all() -def do_gen_blobs(args, gemm_config): - generator = GemmCodeGenerator(args.working_path, gemm_config) - generator.generate_all() +def do_list_blobs(args: argparse.Namespace, + user_provide_config: Optional[GemmConfig] = None): + generator = GemmCodeGenerator(args.working_path, user_provide_config) + generator.list_all_trait_names() + + +def do_gen_blobs(args: argparse.Namespace, + user_provide_config: Optional[GemmConfig] = None): + generator = GemmCodeGenerator(args.working_path, user_provide_config) + generator.generate_all_instance_files() - def main(args): - # Read json file - with open(args.json, 'r') as json_file: - config_data = json.load(json_file) - - gemm_config = GemmConfig(config_data) + + gemm_config = GemmConfig.from_json( + args.config_json) if args.config_json is not None else args.config_json if args.list_blobs: do_list_blobs(args, gemm_config) elif args.gen_blobs: do_gen_blobs(args, gemm_config) else: - # If neither was specified, either do nothing or default to gen_blobs - print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...") + logging.warning( + "No mode specified (use --list_blobs or --gen_blobs). Generating by default...") do_gen_blobs(args, gemm_config) - if __name__ == "__main__": @@ -696,18 +680,18 @@ if __name__ == "__main__": description="gen API for CK gemm kernel", ) parser.add_argument( - "-w", "--working_path", default="./", required=False, help="the path where all the blobs are going to be generated" + "-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated" ) parser.add_argument( - "-j", "--json", required=True, help="Path to the json which contains the kernel configurations" + "-j", "--config_json", required=False, help="Path to the json which contains the configurations that user provide" ) parser.add_argument( - "-l", "--list_blobs", action = 'store_true', help="List all kernel to file" + "-l", "--list_blobs", action='store_true', help="List all kernel instances to file" ) parser.add_argument( - "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels into different files" + "-g", "--gen_blobs", action='store_true', help="Generate all kernel instances into different files" ) - + args = parser.parse_args() - + main(args) diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp new file mode 100644 index 0000000000..0125a759b3 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck_tile/host/device_prop.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "benchmark_gemm.hpp" + +class GemmProfiler +{ + public: + static GemmProfiler& instance(Setting setting) + { + static GemmProfiler instance{setting}; + return instance; + } + + void benchmark(GemmProblem& gemm_problem, + std::vector( + ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables) + { + const ALayout layout_a = ALayout{}; + const BLayout layout_b = BLayout{}; + const CLayout layout_c = CLayout{}; + + gemm_problem.stride_a_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)); + gemm_problem.stride_b_ = ck_tile::get_default_stride( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)); + gemm_problem.stride_c_ = ck_tile::get_default_stride( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)); + + ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b))); + ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.init_method_ == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + else if(setting_.init_method_ == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(setting_.init_method_ == 2) + { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(gemm_problem.structured_sparsity_) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + // permute_tensor_b(b_k_n_dev); + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::GemmHostArgs gemm_args; + gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + gemm_args.k_batch = gemm_problem.split_k_; + gemm_args.M = gemm_problem.m_; + gemm_args.N = gemm_problem.n_; + gemm_args.K = gemm_problem.k_; + gemm_args.stride_A = gemm_problem.stride_a_; + gemm_args.stride_B = gemm_problem.stride_b_; + gemm_args.stride_C = gemm_problem.stride_c_; + + ck_tile::HostTensor c_m_n_host_result(ck_tile::host_tensor_descriptor( + gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); + + if(setting_.verify_) + { + gemm_host_reference(setting_.verify_, + a_m_k, + b_k_n, + c_m_n_host_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + gemm_problem.m_, + gemm_problem.n_, + gemm_problem.k_, + gemm_problem.stride_a_, + gemm_problem.stride_b_, + gemm_problem.stride_c_); + } + + for(auto& callable : callables) + { + auto kernel_run_result = callable(gemm_args, + ck_tile::stream_config{nullptr, + true, + setting_.log_, + setting_.n_warmup_, + setting_.n_repeat_, + setting_.is_gpu_timer_, + setting_.flush_cache_, + setting_.rotating_count_}); + process_result(gemm_problem, + c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + kernel_run_result); + } + } + + void process_result(const GemmProblem& gemm_problem, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + const std::tuple& kernel_run_result) + { + auto [name, avg_time] = kernel_run_result; + + KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}}; + + // compute performance metric + std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_; + std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ + + sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ + + sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_; + + // update + kernel_instance.perf_result_.latency_ = avg_time; + kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; + kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; + + if(setting_.log_ > 0) + { + std::cout << kernel_instance << std::endl; + } + + // verify result + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool verified_correct = + !setting_.verify_ || + compare(gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result); + + if(verified_correct) + { + kernel_instances_.emplace_back(kernel_instance); + } + else + { + std::cout << "Verification failed, skip kernel: " << name << std::endl; + } + + // clear tensor + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + } + + KernelInstance select_best_instance(Metric metric) + { + if(kernel_instances_.empty()) + throw std::runtime_error("Empty instances"); + + auto kernel_instance = *std::max_element(kernel_instances_.begin(), + kernel_instances_.end(), + [metric](const auto& a, const auto& b) { + return PerformanceResult::compare( + b.perf_result_, a.perf_result_, metric); + }); + + std::cout << "**********************************" << std::endl; + std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" + << "The best kernel instance is: " << kernel_instance << std::endl; + std::cout << "**********************************" << std::endl; + + if(!setting_.csv_filename_.empty()) + { + std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app); + + if(!file.is_open()) + { + std::cerr << "Warning: Failed to open CSV file for writing." << std::endl; + } + else + { + if(file.tellp() == 0) + { + file << "rocm_version,device_name," + << "split_k,m,n,k,stride_a,stride_b,stride_c," + << "dtype_a,dtype_b,dtype_acc,dtype_c," + << "layout_a,layout_b,layout_c," + << "structured_sparsity," + << "name," + << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n"; + } + + const auto& problem = kernel_instance.problem_; + const auto& name = kernel_instance.name_; + const auto& perf = kernel_instance.perf_result_; + + file << get_rocm_version() << "," << ck_tile::get_device_name() << "," + << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << "," + << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << "," + << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_ + << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << "," + << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_ + << "," << problem.structured_sparsity_ << "," << name << "," << std::fixed + << std::setprecision(4) << perf.latency_ << "," << std::fixed + << std::setprecision(4) << perf.tflops_ << "," << std::fixed + << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) + << "\n"; + + if(!file) + { + std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl; + } + } + } + + return kernel_instance; + } + + GemmProfiler(const GemmProfiler&) = delete; + GemmProfiler& operator=(const GemmProfiler&) = delete; + + private: + ~GemmProfiler() { kernel_instances_.clear(); } + GemmProfiler(Setting setting) : setting_(setting) {} + + Setting setting_; + + std::vector kernel_instances_; +}; diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py new file mode 100644 index 0000000000..597caba76f --- /dev/null +++ b/tile_engine/ops/gemm/json_config.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +# -*- coding: utf-8 -*- + +""" +Handles loading, parsing, and validation of JSON configuration parameters. +""" + +from pathlib import Path +from dataclasses import dataclass +from typing import List, Optional, Union, Tuple, Type, Dict +import json + + +@dataclass +class EnumConfigParam: + """Represents an enumeration-type configuration parameter""" + values: List[Union[int, str, bool]] + + +@dataclass +class RangeConfigParam: + """Represents a numeric range-type configuration parameter""" + min: int + max: int + step: int + exclude: Optional[List[int]] + + def generate_candidates(self) -> List[int]: + """Generates valid candidates after applying range constraints""" + + if self.min > self.max: + raise ValueError( + f"Invalid range: min({self.min}) > max({self.max})" + ) + if self.step <= 0: + raise ValueError( + f"Step must be positive, got {self.step}" + ) + + candidates = list(range(self.min, self.max + 1, self.step)) + + if hasattr(self, 'exclude') and self.exclude: + if not isinstance(self.exclude, list): + raise TypeError("exclude must be list type") + exclude_set = set(self.exclude) + candidates = [x for x in candidates if x not in exclude_set] + + if not candidates: + raise ValueError( + f"No valid candidates for range [{self.min}-{self.max}] " + f"with step {self.step} and excludes {self.exclude}" + ) + + return candidates + + +@dataclass +class ProblemConfig: + """configuration class for problem parameter.""" + datatypes: Tuple[EnumConfigParam, ...] + layouts: Tuple[EnumConfigParam, ...] + + @property + def datatype_map(self) -> Dict[str, str]: + """Get datatype as a key-value map.""" + return { + 'matrix_a': self.datatypes[0].values[0], + 'matrix_b': self.datatypes[1].values[0], + 'matrix_c': self.datatypes[2].values[0] + } + + @property + def layout_map(self) -> Dict[str, str]: + """Get layout as a key-value map.""" + return { + 'matrix_a': self.layouts[0].values[0], + 'matrix_b': self.layouts[1].values[0], + 'matrix_c': self.layouts[2].values[0] + } + + +@dataclass +class TileConfig: + """Configuration class for tile parameter.""" + tile_m: Union[EnumConfigParam, RangeConfigParam] + tile_n: Union[EnumConfigParam, RangeConfigParam] + tile_k: Union[EnumConfigParam, RangeConfigParam] + + warp_m: Union[EnumConfigParam, RangeConfigParam] + warp_n: Union[EnumConfigParam, RangeConfigParam] + warp_k: Union[EnumConfigParam, RangeConfigParam] + + warp_tile_m: Union[EnumConfigParam, RangeConfigParam] + warp_tile_n: Union[EnumConfigParam, RangeConfigParam] + warp_tile_k: Union[EnumConfigParam, RangeConfigParam] + + +@dataclass +class TraitConfig: + """Configuration class for kernel traits.""" + pipeline: EnumConfigParam + scheduler: EnumConfigParam + epilogue: EnumConfigParam + pad_m: EnumConfigParam + pad_n: EnumConfigParam + pad_k: EnumConfigParam + + +@dataclass +class GemmConfig: + """Main configuration class for GEMM operations """ + problem: ProblemConfig + tile_config: TileConfig + trait_config: TraitConfig + + @classmethod + def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig": + """JSON configuration loader with validation controls""" + config_path = Path(filepath) + + try: + if not config_path.exists(): + raise FileNotFoundError(f"Config file {filepath} not found") + + with config_path.open('r') as f: + config_dict = json.load(f) + + # Parse problem config + problem = ProblemConfig( + datatypes=( + EnumConfigParam( + values=config_dict['problem']['datatype_a']['values']), + EnumConfigParam( + values=config_dict['problem']['datatype_b']['values']), + EnumConfigParam( + values=config_dict['problem']['datatype_c']['values']) + ), + layouts=( + EnumConfigParam( + values=config_dict['problem']['layout_a']['values']), + EnumConfigParam( + values=config_dict['problem']['layout_b']['values']), + EnumConfigParam( + values=config_dict['problem']['layout_c']['values']) + ) + ) + + # Parse tile config + def create_param(param_dict): + if 'values' in param_dict: + return EnumConfigParam(values=param_dict['values']) + else: + return RangeConfigParam( + min=param_dict['min'], + max=param_dict['max'], + step=param_dict['step'], + exclude=param_dict.get('exclude', []) + ) + + tile_config = TileConfig( + tile_m=create_param(config_dict['tile_config']['tile_m']), + tile_n=create_param(config_dict['tile_config']['tile_n']), + tile_k=create_param(config_dict['tile_config']['tile_k']), + warp_m=create_param(config_dict['tile_config']['warp_m']), + warp_n=create_param(config_dict['tile_config']['warp_n']), + warp_k=create_param(config_dict['tile_config']['warp_k']), + warp_tile_m=create_param( + config_dict['tile_config']['warp_tile_m']), + warp_tile_n=create_param( + config_dict['tile_config']['warp_tile_n']), + warp_tile_k=create_param( + config_dict['tile_config']['warp_tile_k']) + ) + + # Parse trait config + trait_config = TraitConfig( + pipeline=EnumConfigParam( + values=config_dict['trait_config']['pipeline']['values']), + scheduler=EnumConfigParam( + values=config_dict['trait_config']['scheduler']['values']), + epilogue=EnumConfigParam( + values=config_dict['trait_config']['epilogue']['values']), + pad_m=EnumConfigParam( + values=config_dict['trait_config']['pad_m']['values']), + pad_n=EnumConfigParam( + values=config_dict['trait_config']['pad_n']['values']), + pad_k=EnumConfigParam( + values=config_dict['trait_config']['pad_k']['values']) + ) + + return cls( + problem=problem, + tile_config=tile_config, + trait_config=trait_config + ) + + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format: {str(e)}") + except KeyError as e: + raise KeyError(f"Missing required configuration field: {str(e)}")