diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000000..5c2b781687 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,3 @@ +CheckOptions: + - key: bugprone-reserved-identifier.AllowedIdentifiers + value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' diff --git a/CMakeLists.txt b/CMakeLists.txt index 0cf342bb45..306e6ca649 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,9 @@ -cmake_minimum_required(VERSION 2.8.3) -project(modular_convolution) +cmake_minimum_required(VERSION 3.5) +project(composable_kernel) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") -include(TargetFlags) -include(AddKernels) +include(CheckCXXCompilerFlag) ## C++ enable_language(CXX) @@ -39,4 +38,161 @@ link_libraries(${OpenMP_pthread_LIBRARY}) find_package(HIP REQUIRED) message(STATUS "Build with HIP ${hip_VERSION}") +## half +#find_path(HALF_INCLUDE_DIR half.hpp) +message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") + +# CMAKE_CXX_FLAGS +SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") +if(BUILD_DEV) + string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") +endif() +message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +## tidy +include(EnableCompilerWarnings) +set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) +if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") + set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) +# Enable tidy on hip +elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU") + set(MIOPEN_TIDY_ERRORS ALL) +endif() + +include(ClangTidy) +enable_clang_tidy( + CHECKS + * + -abseil-* + -android-cloexec-fopen + # Yea we shouldn't be using rand() + -cert-msc30-c + -bugprone-exception-escape + -bugprone-macro-parentheses + -cert-env33-c + -cert-msc32-c + -cert-msc50-cpp + -cert-msc51-cpp + -cert-dcl37-c + -cert-dcl51-cpp + -clang-analyzer-alpha.core.CastToStruct + -clang-analyzer-optin.performance.Padding + -clang-diagnostic-deprecated-declarations + 
-clang-diagnostic-extern-c-compat + -clang-diagnostic-unused-command-line-argument + -cppcoreguidelines-avoid-c-arrays + -cppcoreguidelines-avoid-magic-numbers + -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables + -cppcoreguidelines-macro-usage + -cppcoreguidelines-non-private-member-variables-in-classes + -cppcoreguidelines-pro-bounds-array-to-pointer-decay + -cppcoreguidelines-pro-bounds-constant-array-index + -cppcoreguidelines-pro-bounds-pointer-arithmetic + -cppcoreguidelines-pro-type-member-init + -cppcoreguidelines-pro-type-reinterpret-cast + -cppcoreguidelines-pro-type-union-access + -cppcoreguidelines-pro-type-vararg + -cppcoreguidelines-special-member-functions + -fuchsia-* + -google-explicit-constructor + -google-readability-braces-around-statements + -google-readability-todo + -google-runtime-int + -google-runtime-references + -hicpp-vararg + -hicpp-braces-around-statements + -hicpp-explicit-conversions + -hicpp-named-parameter + -hicpp-no-array-decay + # We really shouldn't use bitwise operators with signed integers, but + # opencl leaves us no choice + -hicpp-avoid-c-arrays + -hicpp-signed-bitwise + -hicpp-special-member-functions + -hicpp-uppercase-literal-suffix + -hicpp-use-auto + -hicpp-use-equals-default + -hicpp-use-override + -llvm-header-guard + -llvm-include-order + #-llvmlibc-* + -llvmlibc-restrict-system-libc-headers + -llvmlibc-callee-namespace + -llvmlibc-implementation-in-namespace + -llvm-else-after-return + -llvm-qualified-auto + -misc-misplaced-const + -misc-non-private-member-variables-in-classes + -misc-no-recursion + -modernize-avoid-bind + -modernize-avoid-c-arrays + -modernize-pass-by-value + -modernize-use-auto + -modernize-use-default-member-init + -modernize-use-equals-default + -modernize-use-trailing-return-type + -modernize-use-transparent-functors + -performance-unnecessary-value-param + -readability-braces-around-statements + -readability-else-after-return + # we are not ready to use it, 
but very useful + -readability-function-cognitive-complexity + -readability-isolate-declaration + -readability-magic-numbers + -readability-named-parameter + -readability-uppercase-literal-suffix + -readability-convert-member-functions-to-static + -readability-qualified-auto + -readability-redundant-string-init + # too many narrowing conversions in our code + -bugprone-narrowing-conversions + -cppcoreguidelines-narrowing-conversions + -altera-struct-pack-align + -cppcoreguidelines-prefer-member-initializer + + ${MIOPEN_TIDY_CHECKS} + ${MIOPEN_TIDY_ERRORS} + HEADER_FILTER + "\.hpp$" + EXTRA_ARGS + -DMIOPEN_USE_CLANG_TIDY +) + +include(CppCheck) +enable_cppcheck( + CHECKS + warning + style + performance + portability + SUPPRESS + ConfigurationNotChecked + constStatement + duplicateCondition + noExplicitConstructor + passedByValue + preprocessorErrorDirective + shadowVariable + unusedFunction + unusedPrivateFunction + unusedStructMember + unmatchedSuppression + FORCE + SOURCES + host/host_tensor/src + host/driver_offline/src + composable_kernel/src/kernel_wrapper + INCLUDE + host/host_tensor/include + host/solver/include + host/driver_offline/include + composable_kernel/include/* + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + DEFINE + CPPCHECK=1 + __linux__=1 +) + add_subdirectory(host) diff --git a/README.md b/README.md index 6e6019601a..4f071d5896 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, } InRightPads size 2, {1, 1, } ConvStrides size 2, {2, 2, } ConvDilations size 2, {1, 1, } -device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw +device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw a_k0_m_k1_grid_desc{216, 256, 8} b_k0_n_k1_grid_desc{216, 165888, 8} c_m_n_grid_desc{ 256, 165888} @@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, } InRightPads size 2, {1, 1, } ConvStrides size 2, {1, 1, } ConvDilations size 2, {1, 1, } 
-device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw +device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw a_k0_m_k1_grid_desc{288, 1024, 8} b_k0_n_k1_grid_desc{288, 50176, 8} c_m_n_grid_desc{ 1024, 50176} @@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, } InRightPads size 2, {1, 1, } ConvStrides size 2, {2, 2, } ConvDilations size 2, {1, 1, } -device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk +device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk a_k0_m_k1_grid_desc{216, 165888, 8} b_k0_n_k1_grid_desc{216, 256, 8} c_m_n_grid_desc{ 165888, 256} @@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, } InRightPads size 2, {1, 1, } ConvStrides size 2, {1, 1, } ConvDilations size 2, {1, 1, } -device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk +device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk a_k0_m_k1_grid_desc{288, 50176, 8} b_k0_n_k1_grid_desc{288, 1024, 8} c_m_n_grid_desc{ 50176, 1024} @@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, } InRightPads size 2, {1, 1, } ConvStrides size 2, {1, 1, } ConvDilations size 2, {1, 1, } -device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk +device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk a_k0_m_k1_grid_desc{288, 50176, 8} b_k0_n_k1_grid_desc{288, 1024, 8} c_m_n_grid_desc{ 50176, 1024} diff --git a/cmake/AddKernels.cmake b/cmake/AddKernels.cmake deleted file mode 100644 index 429ecc47a9..0000000000 --- a/cmake/AddKernels.cmake +++ /dev/null @@ -1,40 +0,0 @@ - -function(add_kernels SRC_DIR KERNEL_FILES) - set(INIT_KERNELS_LIST) - set(KERNELS_DECLS) - foreach(KERNEL_FILE ${KERNEL_FILES}) - if("${CMAKE_VERSION}" VERSION_LESS 3.0) - configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete) - else() - set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE}) - endif() - get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE) - string(TOUPPER 
"${BASE_NAME}" KEY_NAME) - string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME) - string(APPEND KERNELS_DECLS "extern const size_t APP_KERNEL_${VAR_NAME}_SIZE;\n") - string(APPEND KERNELS_DECLS "extern const unsigned char APP_KERNEL_${VAR_NAME}[];\n") - list(APPEND INIT_KERNELS_LIST " { \"${KEY_NAME}\", std::string(reinterpret_cast(APP_KERNEL_${VAR_NAME}), APP_KERNEL_${VAR_NAME}_SIZE) }") - endforeach() - string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}") - configure_file(${SRC_DIR}/kernel.cpp.in ${PROJECT_BINARY_DIR}/kernel.cpp) -endfunction() - -function(add_kernel_includes SRC_DIR KERNEL_FILES) - set(INIT_KERNELS_LIST) - foreach(KERNEL_FILE ${KERNEL_FILES}) - if("${CMAKE_VERSION}" VERSION_LESS 3.0) - configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete) - else() - set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE}) - endif() - get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE) - get_filename_component(FILE_NAME ${KERNEL_FILE} NAME) - string(TOUPPER "${BASE_NAME}" KEY_NAME) - string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME) - list(APPEND INIT_KERNELS_LIST " { \"${FILE_NAME}\", std::string(reinterpret_cast(${VAR_NAME}), ${VAR_NAME}_SIZE) }") - endforeach() - string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}") - configure_file(${SRC_DIR}/kernel_includes.cpp.in ${PROJECT_BINARY_DIR}/kernel_includes.cpp) -endfunction() - - diff --git a/host/online_compile/addkernels/CMakeLists.txt b/cmake/Analyzers.cmake similarity index 90% rename from host/online_compile/addkernels/CMakeLists.txt rename to cmake/Analyzers.cmake index 874cba6a5e..1bf1a52c68 100644 --- a/host/online_compile/addkernels/CMakeLists.txt +++ b/cmake/Analyzers.cmake @@ -24,7 +24,11 @@ # ################################################################################ -set(ADD_KERNELS_SOURCE include_inliner.cpp addkernels.cpp) +if(NOT TARGET analyze) + add_custom_target(analyze) +endif() -add_executable(addkernels EXCLUDE_FROM_ALL ${ADD_KERNELS_SOURCE}) 
+function(mark_as_analyzer) + add_dependencies(analyze ${ARGN}) +endfunction() diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake new file mode 100644 index 0000000000..01b348c458 --- /dev/null +++ b/cmake/ClangTidy.cmake @@ -0,0 +1,162 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################ +include(CMakeParseArguments) +include(Analyzers) + +get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH) + +find_program(CLANG_TIDY_EXE + NAMES + clang-tidy + clang-tidy-5.0 + clang-tidy-4.0 + clang-tidy-3.9 + clang-tidy-3.8 + clang-tidy-3.7 + clang-tidy-3.6 + clang-tidy-3.5 + HINTS + ${CLANG_TIDY_EXE_HINT} + PATH_SUFFIXES + compiler/bin + PATHS + /opt/rocm/llvm/bin + /opt/rocm/hcc + /usr/local/opt/llvm/bin +) + +function(find_clang_tidy_version VAR) + execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT) + separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}") + list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX) + if(VERSION_INDEX GREATER 0) + math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1") + list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION) + set(${VAR} ${VERSION} PARENT_SCOPE) + else() + set(${VAR} "0.0" PARENT_SCOPE) + endif() + +endfunction() + +if( NOT CLANG_TIDY_EXE ) + message( STATUS "Clang tidy not found" ) + set(CLANG_TIDY_VERSION "0.0") +else() + find_clang_tidy_version(CLANG_TIDY_VERSION) + message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}") +endif() + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits) +file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR}) +set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR}) + +macro(enable_clang_tidy) + set(options ANALYZE_TEMPORARY_DTORS ALL) + set(oneValueArgs HEADER_FILTER) + set(multiValueArgs CHECKS ERRORS EXTRA_ARGS) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}") + string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}") + set(CLANG_TIDY_EXTRA_ARGS) + foreach(ARG ${PARSE_EXTRA_ARGS}) + list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}") + endforeach() + + 
set(CLANG_TIDY_ALL) + if(PARSE_ALL) + set(CLANG_TIDY_ALL ALL) + endif() + + message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}") + + if (${PARSE_ANALYZE_TEMPORARY_DTORS}) + set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors") + endif() + + if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0") + set(CLANG_TIDY_ERRORS_ARG "") + else() + set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'") + endif() + + if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0") + set(CLANG_TIDY_QUIET_ARG "") + else() + set(CLANG_TIDY_QUIET_ARG "-quiet") + endif() + + if(PARSE_HEADER_FILTER) + string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}") + else() + set(CLANG_TIDY_HEADER_FILTER ".*") + endif() + + set(CLANG_TIDY_COMMAND + ${CLANG_TIDY_EXE} + ${CLANG_TIDY_QUIET_ARG} + -p ${CMAKE_BINARY_DIR} + -checks='${CLANG_TIDY_CHECKS}' + ${CLANG_TIDY_ERRORS_ARG} + ${CLANG_TIDY_EXTRA_ARGS} + ${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS} + -header-filter='${CLANG_TIDY_HEADER_FILTER}' + ) + add_custom_target(tidy ${CLANG_TIDY_ALL}) + mark_as_analyzer(tidy) + add_custom_target(tidy-base) + add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR}) + add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR}) + add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir) + add_dependencies(tidy-base tidy-make-fixit-dir) +endmacro() + +function(clang_tidy_check TARGET) + get_target_property(SOURCES ${TARGET} SOURCES) + # TODO: Use generator expressions instead + # COMMAND ${CLANG_TIDY_COMMAND} $ + # COMMAND ${CLANG_TIDY_COMMAND} $, > + foreach(SOURCE ${SOURCES}) + if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS")) + string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file) + set(tidy_target tidy-target-${TARGET}-${tidy_file}) + add_custom_target(${tidy_target} + # for some targets clang-tidy not able to get information from .clang-tidy + DEPENDS 
${SOURCE} + COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." + ) + add_dependencies(${tidy_target} ${TARGET}) + add_dependencies(${tidy_target} tidy-base) + add_dependencies(tidy ${tidy_target}) + endif() + endforeach() +endfunction() + diff --git a/cmake/CppCheck.cmake b/cmake/CppCheck.cmake new file mode 100644 index 0000000000..797dcf4b4d --- /dev/null +++ b/cmake/CppCheck.cmake @@ -0,0 +1,130 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################ + +include(CMakeParseArguments) +include(ProcessorCount) +include(Analyzers) + +find_program(CPPCHECK_EXE + NAMES + cppcheck + PATHS + /opt/rocm/bin +) + +ProcessorCount(CPPCHECK_JOBS) + +set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build) +file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR}) +set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR}) + +macro(enable_cppcheck) + set(options FORCE) + set(oneValueArgs) + set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}") + string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*") + file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}") + set(CPPCHECK_DEFINES) + foreach(DEF ${PARSE_DEFINE}) + set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}") + endforeach() + + set(CPPCHECK_UNDEFINES) + foreach(DEF ${PARSE_UNDEFINE}) + set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}") + endforeach() + + set(CPPCHECK_INCLUDES) + foreach(INC ${PARSE_INCLUDE}) + set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}") + endforeach() + + # set(CPPCHECK_FORCE) + set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json") + if(PARSE_FORCE) + set(CPPCHECK_FORCE --force) + endif() + + set(SOURCES) + set(GLOBS) + foreach(SOURCE ${PARSE_SOURCES}) + get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE) + if(EXISTS ${ABS_SOURCE}) + if(IS_DIRECTORY ${ABS_SOURCE}) + set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h") + else() + set(SOURCES "${SOURCES} ${ABS_SOURCE}") + endif() + else() + set(GLOBS "${GLOBS} ${ABS_SOURCE}") + endif() + endforeach() + + file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake " + file(GLOB_RECURSE GSRCS ${GLOBS}) + 
set(CPPCHECK_COMMAND + ${CPPCHECK_EXE} + -q + # -v + # --report-progress + ${CPPCHECK_FORCE} + --cppcheck-build-dir=${CPPCHECK_BUILD_DIR} + --platform=native + --template=gcc + --error-exitcode=1 + -j ${CPPCHECK_JOBS} + ${CPPCHECK_DEFINES} + ${CPPCHECK_UNDEFINES} + ${CPPCHECK_INCLUDES} + --enable=${CPPCHECK_CHECKS} + --inline-suppr + --suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions + ${SOURCES} \${GSRCS} + ) + string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\") + message(\"\${CPPCHECK_SHOW_COMMAND}\") + execute_process( + COMMAND \${CPPCHECK_COMMAND} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE RESULT + ) + if(NOT RESULT EQUAL 0) + message(FATAL_ERROR \"Cppcheck failed\") + endif() +") + + add_custom_target(cppcheck + COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "cppcheck: Running cppcheck..." + ) + mark_as_analyzer(cppcheck) +endmacro() + + diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake new file mode 100644 index 0000000000..2e3669fcdf --- /dev/null +++ b/cmake/DoxygenDoc.cmake @@ -0,0 +1,355 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +include(CMakeParseArguments) +include(MainDoc) + +find_program(DOXYGEN_EXECUTABLE NAMES doxygen + PATH_SUFFIXES bin + DOC "Doxygen documentation generator" +) +mark_as_advanced(DOXYGEN_EXECUTABLE) + +find_path(DOT_EXECUTABLE NAMES dot + PATH_SUFFIXES bin + DOC "Graphviz" +) +mark_as_advanced(DOT_EXECUTABLE) + +set(DOXYGEN_ARGS +ABBREVIATE_BRIEF +ALIASES +ALLEXTERNALS +ALLOW_UNICODE_NAMES +ALPHABETICAL_INDEX +ALWAYS_DETAILED_SEC +AUTOLINK_SUPPORT +BINARY_TOC +BRIEF_MEMBER_DESC +BUILTIN_STL_SUPPORT +CALLER_GRAPH +CALL_GRAPH +CASE_SENSE_NAMES +CHM_FILE +CHM_INDEX_ENCODING +CITE_BIB_FILES +CLANG_ASSISTED_PARSING +CLANG_OPTIONS +CLASS_DIAGRAMS +CLASS_GRAPH +COLLABORATION_GRAPH +COLS_IN_ALPHA_INDEX +COMPACT_LATEX +COMPACT_RTF +CPP_CLI_SUPPORT +CREATE_SUBDIRS +DIAFILE_DIRS +DIA_PATH +DIRECTORY_GRAPH +DISABLE_INDEX +DISTRIBUTE_GROUP_DOC +DOCBOOK_OUTPUT +DOCBOOK_PROGRAMLISTING +DOCSET_BUNDLE_ID +DOCSET_FEEDNAME +DOCSET_PUBLISHER_ID +DOCSET_PUBLISHER_NAME +DOTFILE_DIRS +DOT_CLEANUP +DOT_FONTNAME +DOT_FONTPATH +DOT_FONTSIZE +DOT_GRAPH_MAX_NODES +DOT_IMAGE_FORMAT +DOT_MULTI_TARGETS +DOT_NUM_THREADS +# DOT_PATH +DOT_TRANSPARENT +DOXYFILE_ENCODING +ECLIPSE_DOC_ID +ENABLED_SECTIONS +ENABLE_PREPROCESSING +ENUM_VALUES_PER_LINE +EXAMPLE_PATH +EXAMPLE_PATTERNS +EXAMPLE_RECURSIVE +EXCLUDE +EXCLUDE_PATTERNS +EXCLUDE_SYMBOLS +EXCLUDE_SYMLINKS +EXPAND_AS_DEFINED +EXPAND_ONLY_PREDEF +EXTENSION_MAPPING +EXTERNAL_GROUPS 
+EXTERNAL_PAGES +EXTERNAL_SEARCH +EXTERNAL_SEARCH_ID +EXTRACT_ALL +EXTRACT_ANON_NSPACES +EXTRACT_LOCAL_CLASSES +EXTRACT_LOCAL_METHODS +EXTRACT_PACKAGE +EXTRACT_PRIVATE +EXTRACT_STATIC +EXTRA_PACKAGES +EXTRA_SEARCH_MAPPINGS +EXT_LINKS_IN_WINDOW +FILE_PATTERNS +FILE_VERSION_FILTER +FILTER_PATTERNS +FILTER_SOURCE_FILES +FILTER_SOURCE_PATTERNS +FORCE_LOCAL_INCLUDES +FORMULA_FONTSIZE +FORMULA_TRANSPARENT +FULL_PATH_NAMES +GENERATE_AUTOGEN_DEF +GENERATE_BUGLIST +GENERATE_CHI +GENERATE_DEPRECATEDLIST +GENERATE_DOCBOOK +GENERATE_DOCSET +GENERATE_ECLIPSEHELP +GENERATE_HTML +GENERATE_HTMLHELP +GENERATE_LATEX +GENERATE_LEGEND +GENERATE_MAN +GENERATE_PERLMOD +GENERATE_QHP +GENERATE_RTF +GENERATE_TAGFILE +GENERATE_TESTLIST +GENERATE_TODOLIST +GENERATE_TREEVIEW +GENERATE_XML +GRAPHICAL_HIERARCHY +GROUP_GRAPHS +GROUP_NESTED_COMPOUNDS +# HAVE_DOT +HHC_LOCATION +HIDE_COMPOUND_REFERENCE +HIDE_FRIEND_COMPOUNDS +HIDE_IN_BODY_DOCS +HIDE_SCOPE_NAMES +HIDE_UNDOC_CLASSES +HIDE_UNDOC_MEMBERS +HIDE_UNDOC_RELATIONS +HTML_COLORSTYLE_GAMMA +HTML_COLORSTYLE_HUE +HTML_COLORSTYLE_SAT +HTML_DYNAMIC_SECTIONS +HTML_EXTRA_FILES +HTML_EXTRA_STYLESHEET +HTML_FILE_EXTENSION +HTML_FOOTER +HTML_HEADER +HTML_INDEX_NUM_ENTRIES +HTML_OUTPUT +HTML_STYLESHEET +HTML_TIMESTAMP +IDL_PROPERTY_SUPPORT +IGNORE_PREFIX +IMAGE_PATH +INCLUDED_BY_GRAPH +INCLUDE_FILE_PATTERNS +INCLUDE_GRAPH +INCLUDE_PATH +INHERIT_DOCS +INLINE_GROUPED_CLASSES +INLINE_INFO +INLINE_INHERITED_MEMB +INLINE_SIMPLE_STRUCTS +INLINE_SOURCES +INPUT +INPUT_ENCODING +INPUT_FILTER +INTERACTIVE_SVG +INTERNAL_DOCS +JAVADOC_AUTOBRIEF +LATEX_BATCHMODE +LATEX_BIB_STYLE +LATEX_CMD_NAME +LATEX_EXTRA_FILES +LATEX_EXTRA_STYLESHEET +LATEX_FOOTER +LATEX_HEADER +LATEX_HIDE_INDICES +LATEX_OUTPUT +LATEX_SOURCE_CODE +LATEX_TIMESTAMP +LAYOUT_FILE +LOOKUP_CACHE_SIZE +MACRO_EXPANSION +MAKEINDEX_CMD_NAME +MAN_EXTENSION +MAN_LINKS +MAN_OUTPUT +MAN_SUBDIR +MARKDOWN_SUPPORT +MATHJAX_CODEFILE +MATHJAX_EXTENSIONS +MATHJAX_FORMAT +MATHJAX_RELPATH +MAX_DOT_GRAPH_DEPTH 
+MAX_INITIALIZER_LINES +MSCFILE_DIRS +MSCGEN_PATH +MULTILINE_CPP_IS_BRIEF +OPTIMIZE_FOR_FORTRAN +OPTIMIZE_OUTPUT_FOR_C +OPTIMIZE_OUTPUT_JAVA +OPTIMIZE_OUTPUT_VHDL +OUTPUT_DIRECTORY +OUTPUT_LANGUAGE +PAPER_TYPE +PDF_HYPERLINKS +PERLMOD_LATEX +PERLMOD_MAKEVAR_PREFIX +PERLMOD_PRETTY +PERL_PATH +PLANTUML_CFG_FILE +PLANTUML_INCLUDE_PATH +PLANTUML_JAR_PATH +PREDEFINED +PROJECT_BRIEF +PROJECT_LOGO +PROJECT_NAME +PROJECT_NUMBER +QCH_FILE +QHG_LOCATION +QHP_CUST_FILTER_ATTRS +QHP_CUST_FILTER_NAME +QHP_NAMESPACE +QHP_SECT_FILTER_ATTRS +QHP_VIRTUAL_FOLDER +QT_AUTOBRIEF +QUIET +RECURSIVE +REFERENCED_BY_RELATION +REFERENCES_LINK_SOURCE +REFERENCES_RELATION +REPEAT_BRIEF +RTF_EXTENSIONS_FILE +RTF_HYPERLINKS +RTF_OUTPUT +RTF_SOURCE_CODE +RTF_STYLESHEET_FILE +SEARCHDATA_FILE +SEARCHENGINE +SEARCHENGINE_URL +SEARCH_INCLUDES +SEPARATE_MEMBER_PAGES +SERVER_BASED_SEARCH +SHORT_NAMES +SHOW_FILES +SHOW_GROUPED_MEMB_INC +SHOW_INCLUDE_FILES +SHOW_NAMESPACES +SHOW_USED_FILES +SIP_SUPPORT +SKIP_FUNCTION_MACROS +SORT_BRIEF_DOCS +SORT_BY_SCOPE_NAME +SORT_GROUP_NAMES +SORT_MEMBERS_CTORS_1ST +SORT_MEMBER_DOCS +SOURCE_BROWSER +SOURCE_TOOLTIPS +STRICT_PROTO_MATCHING +STRIP_CODE_COMMENTS +STRIP_FROM_INC_PATH +STRIP_FROM_PATH +SUBGROUPING +TAB_SIZE +TAGFILES +TCL_SUBST +TEMPLATE_RELATIONS +TOC_EXPAND +TOC_INCLUDE_HEADINGS +TREEVIEW_WIDTH +TYPEDEF_HIDES_STRUCT +UML_LIMIT_NUM_FIELDS +UML_LOOK +USE_HTAGS +USE_MATHJAX +USE_MDFILE_AS_MAINPAGE +USE_PDFLATEX +VERBATIM_HEADERS +WARNINGS +WARN_AS_ERROR +WARN_FORMAT +WARN_IF_DOC_ERROR +WARN_IF_UNDOCUMENTED +WARN_LOGFILE +WARN_NO_PARAMDOC +XML_OUTPUT +XML_PROGRAMLISTING +) + +set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") + +function(add_doxygen_doc) + set(options) + set(oneValueArgs) + set(multiValueArgs DEPENDS ${DOXYGEN_ARGS}) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + file(WRITE ${DOXYGEN_CONFIG_FILE} "# 
Auto-generated doxygen configuration file\n") + + foreach(ARG ${DOXYGEN_ARGS}) + if(PARSE_${ARG}) + string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}}) + file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n") + endif() + endforeach() + + if(PARSE_OUTPUT_DIRECTORY) + if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY}) + file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY}) + endif() + endif() + + if(DOT_EXECUTABLE) + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n") + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n") + else() + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n") + endif() + + add_custom_target(doxygen + ${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Building documentation with doxygen" + ) + if(PARSE_OUTPUT_DIRECTORY) + clean_doc_output(${PARSE_OUTPUT_DIRECTORY}) + endif() + mark_as_doc(doxygen) + if(PARSE_DEPENDS) + add_dependencies(doxygen ${PARSE_DEPENDS}) + endif() +endfunction() diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake new file mode 100644 index 0000000000..9f193b2090 --- /dev/null +++ b/cmake/EnableCompilerWarnings.cmake @@ -0,0 +1,110 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +# - Enable warning all for gcc/clang or use /W4 for visual studio + +## Strict warning level +if (MSVC) + # Use the highest warning level for visual studio. + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w") + # set(CMAKE_CXX_WARNING_LEVEL 4) + # if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]") + # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + # else () + # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") + # endif () + + # set(CMAKE_C_WARNING_LEVEL 4) + # if (CMAKE_C_FLAGS MATCHES "/W[0-4]") + # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") + # else () + # set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4") + # endif () + +else() + foreach(COMPILER C CXX) + set(CMAKE_COMPILER_WARNINGS) + # use -Wall for gcc and clang + list(APPEND CMAKE_COMPILER_WARNINGS + -Wall + -Wextra + -Wcomment + -Wendif-labels + -Wformat + -Winit-self + -Wreturn-type + -Wsequence-point + # Shadow is broken on gcc when using lambdas + # -Wshadow + -Wswitch + -Wtrigraphs + -Wundef + -Wuninitialized + -Wunreachable-code + -Wunused + + -Wno-sign-compare + -Wno-extra-semi-stmt + ) + if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") + list(APPEND CMAKE_COMPILER_WARNINGS + -Weverything + -Wno-c++98-compat + -Wno-c++98-compat-pedantic + -Wno-conversion + -Wno-double-promotion + -Wno-exit-time-destructors + -Wno-extra-semi + -Wno-float-conversion + 
-Wno-gnu-anonymous-struct + -Wno-gnu-zero-variadic-macro-arguments + -Wno-missing-prototypes + -Wno-nested-anon-types + -Wno-padded + -Wno-return-std-move-in-c++11 + -Wno-shorten-64-to-32 + -Wno-sign-conversion + -Wno-unknown-warning-option + -Wno-unused-command-line-argument + -Wno-weak-vtables + -Wno-covered-switch-default + ) + else() + if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") + # cmake 3.5.2 does not support >=. + if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1") + list(APPEND CMAKE_COMPILER_WARNINGS + -Wno-ignored-attributes) + endif() + endif() + list(APPEND CMAKE_COMPILER_WARNINGS + -Wno-missing-field-initializers + -Wno-deprecated-declarations + ) + endif() + add_definitions(${CMAKE_COMPILER_WARNINGS}) + endforeach() +endif () diff --git a/cmake/TargetFlags.cmake b/cmake/TargetFlags.cmake deleted file mode 100644 index 4f83fb5d39..0000000000 --- a/cmake/TargetFlags.cmake +++ /dev/null @@ -1,50 +0,0 @@ - -function(get_target_property2 VAR TARGET PROPERTY) - get_target_property(_pflags ${TARGET} ${PROPERTY}) - if(_pflags) - set(${VAR} ${_pflags} PARENT_SCOPE) - else() - set(${VAR} "" PARENT_SCOPE) - endif() -endfunction() - - -macro(append_flags FLAGS TARGET PROPERTY PREFIX) - get_target_property2(_pflags ${TARGET} ${PROPERTY}) - foreach(FLAG ${_pflags}) - if(TARGET ${FLAG}) - target_flags(_pflags2 ${FLAG}) - string(APPEND ${FLAGS} " ${_pflags2}") - else() - string(APPEND ${FLAGS} " ${PREFIX}${FLAG}") - endif() - endforeach() -endmacro() - -macro(append_link_flags FLAGS TARGET PROPERTY) - get_target_property2(_pflags ${TARGET} ${PROPERTY}) - foreach(FLAG ${_pflags}) - if(TARGET ${FLAG}) - target_flags(_pflags2 ${FLAG}) - string(APPEND ${FLAGS} " ${_pflags2}") - elseif(FLAG MATCHES "^-.*") - string(APPEND ${FLAGS} " ${FLAG}") - elseif(EXISTS ${FLAG}) - string(APPEND ${FLAGS} " ${FLAG}") - else() - string(APPEND ${FLAGS} " -l${FLAG}") - endif() - endforeach() -endmacro() - -function(target_flags FLAGS TARGET) - 
set(_flags) - append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "") - append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D") - append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ") - append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ") - append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "") - append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "") - # message("_flags: ${_flags}") - set(${FLAGS} ${_flags} PARENT_SCOPE) -endfunction() diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp index 5c582dea46..09ea16fa23 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -23,9 +23,9 @@ template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( const auto K0 = K / K1; // 
weight tensor - const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( wei_k_y_x_c_grid_desc, make_tuple(make_pass_through_transform(K), make_embed_transform(make_tuple(YDot, YTilda), @@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); + transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); #if 1 - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( wei_k0_k1_ydotslice_xdotslice_c_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_pass_through_transform(C), @@ -143,7 +143,7 @@ 
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( wei_k0_k1_ydotslice_xdotslice_c_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_pass_through_transform(C), @@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( // output tensor // this add padding check - const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( out_n_ho_wo_k_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Ho, I0, I0), @@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(YDot, HTilda), @@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = - transform_dynamic_tensor_descriptor( + transform_tensor_descriptor( out_n_ydot_htilda_xdot_wtilda_k_grid_desc, make_tuple(make_pass_through_transform(N), make_slice_transform(YDot, I0, YDotSlice), @@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( Sequence<5, 6>{})); #if 1 - const auto 
out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), @@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else - const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), @@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( #endif // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(YTilda, HTilda), @@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, make_tuple(make_pass_through_transform(N), make_freeze_transform(IYTilda), @@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( Sequence<2>{}, Sequence<3>{})); - const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( in_n_htildaslice_wtildaslice_c_grid_desc, make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp index 377a1ac29b..9c60e8c3ac 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -26,9 +26,9 @@ template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + 
const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( // A: output tensor // this add padding check - const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( out_n_ho_wo_k_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Ho, I0, I0), @@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( out_n_hop_wop_k_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(YDot, HTilda), @@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = - transform_dynamic_tensor_descriptor( + transform_tensor_descriptor( out_n_ydot_htilda_xdot_wtilda_k_grid_desc, make_tuple(make_pass_through_transform(N), make_slice_transform(YDot, I0, YDotSlice), @@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( Sequence<5, 6>{})); #if 1 - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), 
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), @@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), @@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( #endif // B: weight tensor - const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( wei_k_y_x_c_grid_desc, make_tuple(make_pass_through_transform(K), make_embed_transform(make_tuple(YDot, YTilda), @@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); + transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + 
make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); #if 1 - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( wei_k0_k1_ydotslice_xdotslice_c_grid_desc, make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), make_pass_through_transform(C), @@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); #else - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( wei_k0_k1_ydotslice_xdotslice_c_grid_desc, make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), make_pass_through_transform(C), @@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( #endif // C: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto 
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(YTilda, HTilda), @@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, make_tuple(make_pass_through_transform(N), make_freeze_transform(IYTilda), @@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( Sequence<2>{}, Sequence<3>{})); - const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( in_n_htildaslice_wtildaslice_c_grid_desc, make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), make_pass_through_transform(C)), diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp index 404129365f..093a46256d 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -18,9 +18,9 @@ template __host__ __device__ 
constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( - const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ const auto InRightPadW = in_right_pads[I1]; // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -83,15 +83,15 @@ __host__ __device__ constexpr auto 
transform_forward_convolution_into_gemm_v4r4_ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); const auto in_gemmk_gemmn_global_desc = - transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -109,9 +109,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( - const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -126,9 +126,6 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - const auto Hi = 
in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); @@ -150,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -167,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); const auto in_gemmk_gemmn_global_desc = - transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - 
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); @@ -192,9 +189,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( - const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -209,9 +206,6 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); @@ -235,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ InRightPadW == 0); // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const 
auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 79051d9512..9aa27884da 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -18,9 +18,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const 
TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ const auto InRightPadW = in_right_pads[I1]; // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), @@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto in_gemmk_gemmn_grid_desc = - transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - 
make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); @@ -108,9 +108,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -125,9 +125,6 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); @@ -151,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ InRightPadW == 0); // 
weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), + const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp index 49ae26518e..16ae8b470d 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -2,8 +2,8 @@ #define 
CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -20,9 +20,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( const auto GemmK0 = GemmK / GemmK1; // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( in_n_c_hi_wi_grid_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( in_n_c_hip_wip_grid_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); const auto in_gemmk_gemmn_grid_desc = - transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, 
GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp index 5814e66766..e81c87d046 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -20,9 +20,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -67,21 
+67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( const auto GemmK0 = GemmK / GemmK1; // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, 
ConvStrideH)), @@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto in_gemmk_gemmn_grid_desc = - transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp 
b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index ad9d99f4e7..b0b07505e5 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -23,9 +23,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, - const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( const auto GemmK0 = GemmK / GemmK1; // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_y_ho_x_wo_c_grid_desc = 
transform_tensor_descriptor( in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), @@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); const auto in_gemmk_gemmm_grid_desc = - transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - in_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // B: weight tensor - const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, 
Sequence<0>{})); - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( - wei_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); // C: output tensor - const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp index e709f768cb..f5cb7f4877 100644 --- a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -2,8 +2,8 @@ #define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -24,9 +24,9 @@ template __host__ __device__ constexpr auto 
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( - const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( const auto C1 = C / C0; // weight tensor - const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), - make_tuple(make_unmerge_transform(make_tuple(I1, K)), - make_unmerge_transform(make_tuple(C0, C1 * Y * X))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); + const auto wei_gk0_gm0_gm1_gk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_unmerge_transform(make_tuple(I1, K)), + make_unmerge_transform(make_tuple(C0, C1 * Y * X))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( in_n_c_hi_wi_grid_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor( 
in_n_c_hip_wip_grid_desc, make_tuple(make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(C0, C1)), @@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); - const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor( + const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor( in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, make_tuple(make_merge_transform(make_tuple(C1, Y, X)), make_pass_through_transform(N0), @@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( // output tensor const auto out_n_k_howo_grid_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)); - const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( - out_n_k_howo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(N0, N1)), - make_unmerge_transform(make_tuple(I1, K)), - make_pass_through_transform(Ho * Wo)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); + const auto out_n0_n1_1_k_howo_grid_desc = + transform_tensor_descriptor(out_n_k_howo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(Ho * Wo)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); - const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( + const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor( out_n0_n1_1_k_howo_grid_desc, make_tuple(make_pass_through_transform(I1), make_pass_through_transform(K), diff --git 
a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp similarity index 91% rename from composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp rename to composable_kernel/include/tensor_description/multi_index_transform.hpp index 967517bef7..a33b9aee8d 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ -#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP -#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP +#ifndef CK_MULTI_INDEX_TRANSFORM_HPP +#define CK_MULTI_INDEX_TRANSFORM_HPP #include "common_header.hpp" #include "multi_index.hpp" @@ -7,7 +7,7 @@ namespace ck { template -struct DynamicPassThrough +struct PassThrough { using LowerIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>; @@ -16,9 +16,9 @@ struct DynamicPassThrough UpLengths up_lengths_; - __host__ __device__ constexpr DynamicPassThrough() = default; + __host__ __device__ constexpr PassThrough() = default; - __host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length) + __host__ __device__ constexpr PassThrough(const LowLength& low_length) : up_lengths_{make_tuple(low_length)} { } @@ -82,33 +82,36 @@ struct DynamicPassThrough __host__ __device__ void Print() const { printf("{"); - printf("DynamicPassThrough, "); + printf("PassThrough, "); printf("up_lengths_"); print_multi_index(up_lengths_); printf("}"); } }; -template -struct DynamicPad +template +struct Pad { using LowerIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>; - using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{})); + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{})); UpLengths up_lengths_; - LeftPad left_pad_; - RightPad right_pad_; + LeftPadLength left_pad_length_; + RightPadLength right_pad_length_; - __host__ 
__device__ constexpr DynamicPad() = default; + __host__ __device__ constexpr Pad() = default; - __host__ __device__ constexpr DynamicPad(const LowLength& low_length, - const LeftPad& left_pad, - const RightPad& right_pad) - : up_lengths_{make_tuple(low_length + left_pad + right_pad)}, - left_pad_{left_pad}, - right_pad_{right_pad} + __host__ __device__ constexpr Pad(const LowLength& low_length, + const LeftPadLength& left_pad_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)}, + left_pad_length_{left_pad_length}, + right_pad_length_{right_pad_length} { } @@ -125,7 +128,7 @@ struct DynamicPad static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, "wrong! inconsistent # of dimension"); - idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; } template {}] >= left_pad_) && - (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_)); + return SkipIsValidCheck || + ((idx_up[Number<0>{}] >= left_pad_length_) && + (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_)); } __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return is_known_at_compile_time::value && - is_known_at_compile_time::value && - is_known_at_compile_time::value; + is_known_at_compile_time::value && + is_known_at_compile_time::value; } __host__ __device__ void Print() const { printf("{"); - printf("DynamicPad, "); + printf("Pad, "); printf("up_lengths_"); print_multi_index(up_lengths_); - printf("left_pad_ %d", index_t{left_pad_}); - printf("right_pad_ %d", index_t{right_pad_}); + printf("left_pad_length %d", index_t{left_pad_length_}); + printf("right_pad_length %d", index_t{right_pad_length_}); printf("}"); } }; -template -struct DynamicLeftPad +template +struct LeftPad { using LowerIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>; - using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{})); + using UpLengths 
= decltype(make_tuple(LowLength{} + LeftPadLength{})); UpLengths up_lengths_; - LeftPad left_pad_; + LeftPadLength left_pad_length_; - __host__ __device__ constexpr DynamicLeftPad() = default; + __host__ __device__ constexpr LeftPad() = default; - __host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length, - const LeftPad& left_pad) - : up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad} + __host__ __device__ constexpr LeftPad(const LowLength& low_length, + const LeftPadLength& left_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length} { } @@ -216,7 +220,7 @@ struct DynamicLeftPad static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, "wrong! inconsistent # of dimension"); - idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; } template {}] >= left_pad_); + return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_); } __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return is_known_at_compile_time::value && - is_known_at_compile_time::value; + is_known_at_compile_time::value; } __host__ __device__ void Print() const { printf("{"); - printf("DynamicLeftPad, "); + printf("LeftPad, "); printf("up_lengths_"); print_multi_index(up_lengths_); - printf("left_pad_ %d", index_t{left_pad_}); + printf("left_pad_length_ %d", index_t{left_pad_length_}); printf("}"); } }; -template -struct DynamicRightPad +template +struct RightPad { using LowerIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>; - using UpLengths = decltype(make_tuple(LowLength{} + RightPad{})); + using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{})); UpLengths up_lengths_; LowLength low_length_; - RightPad right_pad_; + RightPadLength right_pad_length_; - __host__ __device__ constexpr DynamicRightPad() = default; + __host__ __device__ constexpr RightPad() = default; - __host__ __device__ constexpr 
DynamicRightPad(const LowLength& low_length, - const RightPad& right_pad) - : up_lengths_{make_tuple(low_length + right_pad)}, + __host__ __device__ constexpr RightPad(const LowLength& low_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + right_pad_length)}, low_length_{low_length}, - right_pad_{right_pad} + right_pad_length_{right_pad_length} { } @@ -350,17 +354,17 @@ struct DynamicRightPad { return is_known_at_compile_time::value && is_known_at_compile_time::value && - is_known_at_compile_time::value; + is_known_at_compile_time::value; } __host__ __device__ void Print() const { printf("{"); - printf("DynamicRightPad, "); + printf("RightPad, "); printf("up_lengths_"); print_multi_index(up_lengths_); printf("low_length_ %d", index_t{low_length_}); - printf("left_pad_ %d", index_t{right_pad_}); + printf("left_pad_length_ %d", index_t{right_pad_length_}); printf("}"); } }; @@ -373,8 +377,8 @@ struct DynamicRightPad // at compile-time template ::type = false> -struct DynamicEmbed + typename enable_if::type = false> +struct Embed { static constexpr index_t NDimUp = UpLengths::Size(); @@ -384,10 +388,10 @@ struct DynamicEmbed UpLengths up_lengths_; Coefficients coefficients_; - __host__ __device__ constexpr DynamicEmbed() = default; + __host__ __device__ constexpr Embed() = default; - __host__ __device__ constexpr DynamicEmbed(const UpLengths& up_lengths, - const Coefficients& coefficients) + __host__ __device__ constexpr Embed(const UpLengths& up_lengths, + const Coefficients& coefficients) : up_lengths_{up_lengths}, coefficients_{coefficients} { } @@ -458,7 +462,7 @@ struct DynamicEmbed __host__ __device__ void Print() const { printf("{"); - printf("DynamicEmbed, "); + printf("Embed, "); printf("up_lengths_ "); print_multi_index(up_lengths_); printf("coefficients_ "); @@ -470,7 +474,7 @@ struct DynamicEmbed // Implementation of "Merge" transformation primitive that uses regular to do lowering of // multi-index and use 
carry-and-borrow check to do lowering of multi-index delta template -struct DynamicMerge_v1_carry_check +struct Merge_v1_carry_check { static constexpr index_t NDimLow = LowLengths::Size(); @@ -487,9 +491,9 @@ struct DynamicMerge_v1_carry_check LowLengthsScan low_lengths_scan_; UpLengths up_lengths_; - __host__ __device__ constexpr DynamicMerge_v1_carry_check() = default; + __host__ __device__ constexpr Merge_v1_carry_check() = default; - __host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths) + __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths) : low_lengths_{low_lengths}, low_lengths_scan_{ container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, @@ -555,7 +559,7 @@ struct DynamicMerge_v1_carry_check LowerIndex idx_low_length_minus_idx_diff_low_const; LowerIndex idx_low_length_plus_idx_diff_low_const; -#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE index_t tmp = idx_diff_up[Number<0>{}]; static_for<0, NDimLow - 1, 1>{}([&](auto i) { @@ -698,7 +702,7 @@ struct DynamicMerge_v1_carry_check LowerIndex idx_low_length_minus_idx_diff_low_const; LowerIndex idx_low_length_plus_idx_diff_low_const; -#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE index_t tmp = idx_diff_up[Number<0>{}]; static_for<0, NDimLow - 1, 1>{}([&](auto i) { @@ -838,7 +842,7 @@ struct DynamicMerge_v1_carry_check // very expensive. 
LowerIndex idx_diff_low_const; -#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE index_t tmp = idx_diff_up[Number<0>{}]; static_for<0, NDimLow - 1, 1>{}([&](auto i) { @@ -981,7 +985,7 @@ struct DynamicMerge_v1_carry_check __host__ __device__ void Print() const { printf("{"); - printf("DynamicMerge_v1_carry_check, "); + printf("Merge_v1_carry_check, "); printf("low_lengths_ "); print_multi_index(low_lengths_); printf("low_lengths_scan_ "); @@ -1025,7 +1029,7 @@ struct lambda_merge_generate_MagicDivision_calculate_magic_shift // 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be // non-negative. template -struct DynamicMerge_v2_magic_division +struct Merge_v2_magic_division { static constexpr index_t NDimLow = LowLengths::Size(); @@ -1048,9 +1052,9 @@ struct DynamicMerge_v2_magic_division LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; UpLengths up_lengths_; - __host__ __device__ constexpr DynamicMerge_v2_magic_division() = default; + __host__ __device__ constexpr Merge_v2_magic_division() = default; - __host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths) + __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths) : low_lengths_{low_lengths}, low_lengths_magic_divisor_multiplier_{generate_tuple( [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); }, @@ -1151,7 +1155,7 @@ struct DynamicMerge_v2_magic_division __host__ __device__ void Print() const { printf("{"); - printf("DynamicMerge_v2_magic_division, "); + printf("Merge_v2_magic_division, "); printf("low_lengths_ "); print_multi_index(low_lengths_); printf("low_lengths_magic_divisor_multiplier_ "); @@ -1177,7 +1181,7 @@ struct DynamicMerge_v2_magic_division // 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be // non-negative. 
template -struct DynamicMerge_v2r2_magic_division +struct Merge_v2r2_magic_division { static constexpr index_t NDimLow = LowLengths::Size(); @@ -1204,9 +1208,9 @@ struct DynamicMerge_v2r2_magic_division LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; UpLengths up_lengths_; - __host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default; + __host__ __device__ constexpr Merge_v2r2_magic_division() = default; - __host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths) + __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths) : low_lengths_{low_lengths}, low_lengths_scan_{ container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, @@ -1308,7 +1312,7 @@ struct DynamicMerge_v2r2_magic_division __host__ __device__ void Print() const { printf("{"); - printf("DynamicMerge_v2r2_magic_division, "); + printf("Merge_v2r2_magic_division, "); printf("low_lengths_ "); print_multi_index(low_lengths_); printf("low_lengths_scan "); @@ -1324,7 +1328,7 @@ struct DynamicMerge_v2r2_magic_division }; template -struct DynamicUnMerge +struct UnMerge { static constexpr index_t NDimUp = UpLengths::Size(); @@ -1337,9 +1341,9 @@ struct DynamicUnMerge UpLengths up_lengths_; UpLengthsScan up_lengths_scan_; - __host__ __device__ constexpr DynamicUnMerge() = default; + __host__ __device__ constexpr UnMerge() = default; - __host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths) + __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths) : up_lengths_{up_lengths}, up_lengths_scan_{ container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})} @@ -1414,7 +1418,7 @@ struct DynamicUnMerge __host__ __device__ void Print() const { printf("{"); - printf("DynamicUnMerge, "); + printf("UnMerge, "); printf("up_lengths_"); print_multi_index(up_lengths_); printf("up_lengths_scan_"); @@ -1424,13 +1428,13 @@ struct DynamicUnMerge }; template 
-struct DynamicFreeze +struct Freeze { LowerIndex low_idx_; - __host__ __device__ constexpr DynamicFreeze() = default; + __host__ __device__ constexpr Freeze() = default; - __host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } @@ -1483,22 +1487,22 @@ struct DynamicFreeze __host__ __device__ void Print() const { - printf("DynamicFreeze"); + printf("Freeze"); printf("low_idx_ %d", index_t{low_idx_}); } }; // Insert a dangling upper dimension without lower dimension template -struct DynamicInsert +struct Insert { using UpLengths = decltype(make_tuple(UpperLength{})); UpLengths up_lengths_; - __host__ __device__ constexpr DynamicInsert() = default; + __host__ __device__ constexpr Insert() = default; - __host__ __device__ constexpr DynamicInsert(const UpperLength& up_length) + __host__ __device__ constexpr Insert(const UpperLength& up_length) : up_lengths_{make_tuple(up_length)} { } @@ -1550,13 +1554,13 @@ struct DynamicInsert __host__ __device__ void Print() const { - printf("DynamicInsert"); + printf("Insert"); print_multi_index(up_lengths_); } }; template -struct DynamicVectorize +struct Vectorize { using LowerIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>; @@ -1566,10 +1570,10 @@ struct DynamicVectorize UpLengths up_lengths_; VectorSize vector_size_; - __host__ __device__ constexpr DynamicVectorize() = default; + __host__ __device__ constexpr Vectorize() = default; - __host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size, - const UpLength& up_length) + __host__ __device__ constexpr Vectorize(const VectorSize& vector_size, + const UpLength& up_length) : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} { } @@ -1633,7 +1637,7 @@ struct DynamicVectorize __host__ __device__ void Print() const { printf("{"); - 
printf("DynamicVectorize, "); + printf("Vectorize, "); printf("up_lengths_"); print_multi_index(up_lengths_); printf("}"); @@ -1641,7 +1645,7 @@ struct DynamicVectorize }; template -struct DynamicSlice +struct Slice { using LowerIndex = MultiIndex<1>; using UpperIndex = MultiIndex<1>; @@ -1652,11 +1656,11 @@ struct DynamicSlice SliceBegin slice_begin_; SliceEnd slice_end_; - __host__ __device__ constexpr DynamicSlice() = default; + __host__ __device__ constexpr Slice() = default; - __host__ __device__ constexpr DynamicSlice(const LowLength&, - const SliceBegin& slice_begin, - const SliceEnd& slice_end) + __host__ __device__ constexpr Slice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) : up_lengths_{make_tuple(slice_end - slice_begin)}, slice_begin_{slice_begin}, slice_end_{slice_end} @@ -1724,7 +1728,7 @@ struct DynamicSlice __host__ __device__ void Print() const { printf("{"); - printf("DynamicSlice, "); + printf("Slice, "); printf("up_lengths_"); print_multi_index(up_lengths_); printf("slice_begin_ %d", index_t{slice_begin_}); diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp similarity index 63% rename from composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp rename to composable_kernel/include/tensor_description/multi_index_transform_helper.hpp index b3e1c60485..6d4e01888b 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -1,15 +1,15 @@ -#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP -#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP +#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP +#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform.hpp" +#include 
"multi_index_transform.hpp" namespace ck { template __host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) { - return DynamicPassThrough{low_length}; + return PassThrough{low_length}; } template @@ -19,47 +19,46 @@ make_pad_transform(const LowLength& low_length, const RightPad& right_pad, integral_constant = integral_constant{}) { - return DynamicPad{ - low_length, left_pad, right_pad}; + return Pad{low_length, left_pad, right_pad}; } -template +template __host__ __device__ constexpr auto make_left_pad_transform( const LowLength& low_length, - const LeftPad& left_pad, + const LeftPadLength& left_pad, integral_constant = integral_constant{}) { - return DynamicLeftPad{low_length, left_pad}; + return LeftPad{low_length, left_pad}; } -template +template __host__ __device__ constexpr auto make_right_pad_transform( const LowLength& low_length, - const RightPad& right_pad, + const RightPadLength& right_pad, integral_constant = integral_constant{}) { - return DynamicRightPad{low_length, right_pad}; + return RightPad{low_length, right_pad}; } template ::type = false> + typename enable_if::type = false> __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, const Coefficients& coefficients) { - return DynamicEmbed{up_lengths, coefficients}; + return Embed{up_lengths, coefficients}; } template __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) { #if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION - return DynamicMerge_v1_carry_check{low_lengths}; + return Merge_v1_carry_check{low_lengths}; #else #if 1 - return DynamicMerge_v2_magic_division{low_lengths}; + return Merge_v2_magic_division{low_lengths}; #else - return DynamicMerge_v2r2_magic_division{low_lengths}; + return Merge_v2r2_magic_division{low_lengths}; #endif #endif } @@ -68,7 +67,7 @@ template __host__ __device__ constexpr auto make_merge_transform_v2_magic_division(const LowLengths& low_lengths) { - return 
DynamicMerge_v2_magic_division{low_lengths}; + return Merge_v2_magic_division{low_lengths}; } template @@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform( const UpLengths& up_lengths, integral_constant = integral_constant{}) { - return DynamicUnMerge{up_lengths}; + return UnMerge{up_lengths}; } template __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) { - return DynamicFreeze{low_idx}; + return Freeze{low_idx}; } template @@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len const SliceBegin& slice_begin, const SliceEnd& slice_end) { - return DynamicSlice{low_length, slice_begin, slice_end}; + return Slice{low_length, slice_begin, slice_end}; } template __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, const UpLength& up_length) { - return DynamicVectorize{vector_size, up_length}; + return Vectorize{vector_size, up_length}; } } // namespace ck diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp index 6affe6141f..f684ce5e0f 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -2,8 +2,8 @@ #define CK_TENSOR_ADAPTOR_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf remove_cv_t>{transforms}; } -template = 2, bool>::type = false> +template = 2, bool>::type = false> __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... 
xs) { return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp similarity index 85% rename from composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp rename to composable_kernel/include/tensor_description/tensor_descriptor.hpp index b9ca26c879..4038ef63da 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -1,16 +1,16 @@ -#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP -#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP +#ifndef CK_TENSOR_DESCRIPTOR_HPP +#define CK_TENSOR_DESCRIPTOR_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform.hpp" +#include "multi_index_transform.hpp" namespace ck { template -struct DynamicTensorCoordinate; +struct TensorCoordinate; template -struct DynamicTensorCoordinateIterator; +struct TensorCoordinateStep; // Transforms: Tuple // LowerDimensionIdss : Tuple, ...> @@ -21,7 +21,7 @@ template -struct DynamicTensorDescriptor +struct TensorDescriptor { // TODO make these private __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } @@ -105,16 +105,16 @@ struct DynamicTensorDescriptor using VisibleIndex = MultiIndex; using HiddenIndex = MultiIndex; - using Coordinate = DynamicTensorCoordinate; + using Coordinate = TensorCoordinate; // may be index_t or Number<> using ElementSize = remove_cv_t; public: - __host__ __device__ constexpr DynamicTensorDescriptor() = default; + __host__ __device__ constexpr TensorDescriptor() = default; - __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms, - ElementSpaceSize element_space_size) + __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) : transforms_{transforms}, 
element_size_{InitializeElementSize(transforms)}, element_space_size_{element_space_size} @@ -159,7 +159,7 @@ struct DynamicTensorDescriptor { static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension"); - return make_dynamic_tensor_coordinate(*this, idx).GetOffset(); + return make_tensor_coordinate(*this, idx).GetOffset(); } // TODO make these private @@ -196,7 +196,7 @@ struct DynamicTensorDescriptor __host__ __device__ void Print() const { printf("{"); - printf("DynamicTensorDescriptor, "); + printf("TensorDescriptor, "); static_for<0, ntransform_, 1>{}([&](auto i) { printf("transforms: "); transforms_[i].Print(); @@ -217,7 +217,7 @@ struct DynamicTensorDescriptor }; template -struct DynamicTensorCoordinate +struct TensorCoordinate { // TODO make these private static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); @@ -226,9 +226,9 @@ struct DynamicTensorCoordinate using VisibleIndex = MultiIndex; public: - __host__ __device__ constexpr DynamicTensorCoordinate() = default; + __host__ __device__ constexpr TensorCoordinate() = default; - __host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden) + __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) : idx_hidden_{idx_hidden} { } @@ -252,16 +252,16 @@ struct DynamicTensorCoordinate }; template -struct DynamicTensorCoordinateIterator +struct TensorCoordinateStep { // TODO make these private using VisibleIndex = MultiIndex; public: - __host__ __device__ constexpr DynamicTensorCoordinateIterator() = default; + __host__ __device__ constexpr TensorCoordinateStep() = default; - __host__ __device__ constexpr DynamicTensorCoordinateIterator( - const VisibleIndex& idx_diff_visible, const MultiIndex& do_transforms) + __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible, + const MultiIndex& do_transforms) : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} { } @@ -283,7 +283,7 @@ 
struct DynamicTensorCoordinateIterator // TODO: How to fix this? It uses an struct instead of lambda because lambda // doesn't have constructor, and to put it outside the scope where it is used -// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function +// (transform_tensor_descriptor) because template cannot be defined inside a function // template template struct lambda_get_up_dim_num @@ -301,10 +301,10 @@ template __host__ __device__ constexpr auto -transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, - const NewTransforms& new_transforms, - NewLowerDimensionOldVisibleIdss, - NewUpperDimensionNewVisibleIdss) +transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) { // sanity check { @@ -376,17 +376,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); - return DynamicTensorDescriptor, - remove_cv_t, - remove_cv_t, - remove_cv_t, - remove_cv_t>{all_transforms, - element_space_size}; + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms, + element_space_size}; } template -__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc, - const VisibleIndex& idx_visible) +__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc, + const VisibleIndex& idx_visible) { static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), "wrong! 
# of dimension inconsistent"); @@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe set_container_subset(idx_hidden, dims_low, idx_low); }); - return DynamicTensorCoordinate{idx_hidden}; + return TensorCoordinate{idx_hidden}; } // UpdateLowerIndexHack: Sequence<...> // HACK: control UpdateLowerIndex template -__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( - const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible, + UpdateLowerIndexHack) { static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), "wrong! # of dimension inconsistent"); @@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); }); - return DynamicTensorCoordinateIterator{ - idx_diff_visible, do_transforms}; + return TensorCoordinateStep{idx_diff_visible, + do_transforms}; } template -__host__ __device__ constexpr auto -make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible) { constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); - return make_dynamic_tensor_coordinate_iterator( + return make_tensor_coordinate_step( TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); } -template -__host__ __device__ constexpr void move_dynamic_tensor_coordinate( - const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator) +template +__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc, + TensorCoord& coord, + const TensorCoordStep& coord_step) { constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); 
constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); @@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate( auto idx_diff_hidden = make_zero_multi_index(); // initialize visible index diff - set_container_subset(idx_diff_hidden, - TensorDesc::GetVisibleDimensionIds(), - coord_iterator.GetVisibleIndexDiff()); + set_container_subset( + idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff()); // this is what needs to be updated auto& idx_hidden = coord.GetHiddenIndex(); @@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate( auto idx_hidden_pick_visible = get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); - idx_hidden_pick_visible += coord_iterator.GetIndexDiff(); + idx_hidden_pick_visible += coord_step.GetIndexDiff(); set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); // update rest of hidden index static_for{}([&](auto itran) { - if(coord_iterator.do_transforms_[itran]) + if(coord_step.do_transforms_[itran]) { const auto& tran = tensor_desc.GetTransforms().At(itran); constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); @@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate( MultiIndex idx_diff_low; - // HACK: control UpdateLowerIndex for DynamicMerge using hack - constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); + // HACK: control UpdateLowerIndex for Merge using hack + constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran); tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); @@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& } template -using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate( +using TensorCoordinate_t = decltype(make_tensor_coordinate( TensorDesc{}, 
MultiIndex>::GetNumOfDimension()>{})); template -using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator( +using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); } // namespace ck diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp similarity index 72% rename from composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp rename to composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp index 2e36451a66..cf329f06a5 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp @@ -1,9 +1,9 @@ -#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP -#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP +#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP +#define CK_TENSOR_DESCRIPTOR_HELPER_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "multi_index_transform_helper.hpp" namespace ck { @@ -37,10 +37,9 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt template ::type = false> -__host__ __device__ constexpr auto -make_dynamic_naive_tensor_descriptor_v2(const Tuple& lengths, - const Tuple& strides) + typename enable_if::type = false> +__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple& lengths, + const Tuple& strides) { constexpr index_t N = sizeof...(Lengths); @@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple& lengths, calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); #endif - return DynamicTensorDescriptor, - remove_cv_t, - remove_cv_t, - remove_cv_t, - remove_cv_t>{transforms, - 
element_space_size}; + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; } // Lengths... can be: @@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple& lengths, // 2) Number<>, which is known at compile-time template __host__ __device__ constexpr auto -make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple& lengths) +make_naive_tensor_descriptor_packed(const Tuple& lengths) { constexpr index_t N = sizeof...(Lengths); @@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple& lengths) const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); - return DynamicTensorDescriptor, - remove_cv_t, - remove_cv_t, - remove_cv_t, - remove_cv_t>{transforms, - element_space_size}; + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; } template __host__ __device__ constexpr auto -make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple& lengths, Align align) +make_naive_tensor_descriptor_aligned_v2(const Tuple& lengths, Align align) { constexpr auto I1 = Number<1>{}; @@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple& lengths }, Number{}); - return make_dynamic_naive_tensor_descriptor_v2(lengths, strides); + return make_naive_tensor_descriptor_v2(lengths, strides); } } // namespace ck diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp index 694cf9c6a3..35ff66a2b0 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp @@ -3,7 +3,7 @@ #include "common_header.hpp" #include "tensor_adaptor.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" #include 
"threadwise_contraction_dlops.hpp" namespace ck { @@ -22,24 +22,24 @@ namespace ck { // 2. CThreadBuffer is StaticBuffer // Also assume: // M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) -template ::type = false> +template < + index_t BlockSize, + typename FloatA, + typename FloatB, + typename FloatC, + typename AKMBlockDesc, + typename BKNBlockDesc, + index_t M1PerThreadM11, + index_t N1PerThreadN11, + index_t KPerThread, + index_t M1N1ThreadClusterM100, + index_t M1N1ThreadClusterN100, + index_t M1N1ThreadClusterM101, + index_t M1N1ThreadClusterN101, + index_t AThreadCopyScalarPerVector_M11, + index_t BThreadCopyScalarPerVector_N11, + typename enable_if::type = false> struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 { using AIndex = MultiIndex<3>; @@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 static constexpr index_t N0 = N / N1; __host__ __device__ static constexpr auto - MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc) + MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */) { - const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( + const auto a_k_m0_m1_block_desc = transform_tensor_descriptor( AKMBlockDesc{}, make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform(make_tuple(Number{}, Number{}))), @@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 } __host__ __device__ static constexpr auto - MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc) + MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */) { - const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( + const auto b_k_n0_n1_block_desc = transform_tensor_descriptor( BKNBlockDesc{}, make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform(make_tuple(Number{}, Number{}))), @@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 typename ABlockBuffer, typename BBlockBuffer, 
typename CThreadBuffer> - __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc, + __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */, const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const @@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 private: // A[K, M0, M1] - static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, Number{})); // B[K, N0, N1] - static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, Number{})); - using AThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2>, - 2, - AThreadCopyScalarPerVector_M11, - 1>; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M11, + 1>; - using BThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2>, - 2, - BThreadCopyScalarPerVector_N11, - 1>; + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N11, + 1>; CIndex c_thread_origin_data_idx_; diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp index 6a3885936e..26ca0bf111 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp @@ -3,7 +3,7 @@ #include "common_header.hpp" #include "tensor_adaptor.hpp" -#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer_v2.hpp" #include "threadwise_contraction_dlops.hpp" namespace ck { 
@@ -38,9 +38,9 @@ template index_t AThreadCopyScalarPerVector_BM11, index_t BThreadCopyScalarPerVector_BN11, - typename std::enable_if::type = false> + typename enable_if::type = false> struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; @@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B __host__ __device__ static constexpr auto MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) { - const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor( + const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( a_block_desc_bk0_bm_bk1, make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform(make_tuple(Number{}, Number{})), @@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B __host__ __device__ static constexpr auto MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) { - const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor( + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( b_block_desc_bk0_bn_bk1, make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform(make_tuple(Number{}, Number{})), @@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B private: // A[BK0, BM0, BM1, BK1] static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number{}, Number{}, Number{})); // B[BK0, BN0, BN1, BK1] static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number{}, Number{}, Number{})); - using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< + using AThreadCopy 
= ThreadwiseTensorSliceTransfer_v4r1< FloatA, FloatA, decltype(a_block_desc_bk0_bm0_bm1_bk1_), @@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< FloatB, FloatB, decltype(b_block_desc_bk0_bn0_bn1_bk1_), diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp index 074d519b76..03f889649e 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 // HACK: fix this @Jing Zhang static constexpr index_t KPerThreadSubC = 4; - static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2( + static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number<1>{}, Number{}, Number{})); - static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number<1>{}, Number{}, Number{})); - using AThreadCopy = - ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1>, - 1, - ThreadGemmADataPerRead_K, - 1>; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1>, + 1, + ThreadGemmADataPerRead_K, + 1>; __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() : 
c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, @@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 "wrong! K dimension not consistent\n"); constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed - constexpr index_t N = BlockMatrixB{}.GetLength(I1); constexpr index_t H = BlockMatrixB{}.GetLength(I2); constexpr index_t W = BlockMatrixB{}.GetLength(I3); @@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; constexpr auto a_block_mtx = BlockMatrixA{}; @@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 static_assert(WPerThread % WoPerThreadSubC == 0, ""); // thread A buffer for GEMM - StaticBuffer + StaticBuffer a_thread_buf; constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - K1, - 1>; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + K1, + 1>; - using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - K1, - 1>; + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + K1, + 1>; AThreadCopy 
a_thread_copy_; BThreadCopy b_thread_copy_; @@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline const index_t waveId = thread_id / WaveSize; const index_t laneId = thread_id % WaveSize; const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; if constexpr(xdlops_gemm.IsKReduction) { @@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline const index_t thread_id = get_thread_local_1d_id(); const index_t waveId = thread_id / WaveSize; const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; const index_t waveId_n = waveId % NWaves; if constexpr(xdlops_gemm.IsKReduction) @@ -490,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline private: // A[K, M] - static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + 1, // K1, + 1>; - using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + 1, // K1, + 1>; AThreadCopy a_thread_copy_; BThreadCopy 
b_thread_copy_; diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp similarity index 66% rename from composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp rename to composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp index 694b2fd2cc..cf21123de6 100644 --- a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp @@ -1,18 +1,18 @@ -#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP -#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" #include "cluster_descriptor.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" namespace ck { // this version does following things to avoid scratch memory issue // 1. Use StaticallyIndexedArray instead of C array for thread buffer -// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor -// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. 
ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate template -struct BlockwiseDynamicTensorSliceTransfer_v4 +struct BlockwiseTensorSliceTransfer_v4 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); using Index = MultiIndex; - __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin) + __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) : threadwise_transfer_( src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) @@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 } } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); } } @@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 } } - // SrcMoveSliceWindowIteratorHack to control index calculation move slice window - template + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow( 
- src_desc, step, src_move_slice_window_iterator_hack); + src_desc, step, src_move_slice_window_step_hack); } } @@ -147,22 +146,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4 make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseDynamicTensorSliceTransfer_v3; + ThreadwiseTensorSliceTransfer_v3; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp similarity index 67% rename from composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp rename to composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp index 20f3225f82..4f3336f9f7 100644 --- a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp @@ -1,18 +1,18 @@ -#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP -#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" #include "cluster_descriptor.hpp" -#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer_v2.hpp" namespace ck { // this version does following things to avoid scratch memory issue // 1. Use StaticallyIndexedArray instead of C array for thread buffer -// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor -// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +// 2. 
ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate template -struct BlockwiseDynamicTensorSliceTransfer_v4r1 +struct BlockwiseTensorSliceTransfer_v4r1 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); using Index = MultiIndex; - __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1( - const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin) + __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) : threadwise_transfer_( src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) @@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 } } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); } } @@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 } } - // SrcMoveSliceWindowIteratorHack to control index calculation move slice window - template + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { if(BlockSize == thread_cluster_desc_.GetElementSize() or 
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrcSliceWindow( - src_desc, step, src_move_slice_window_iterator_hack); + src_desc, step, src_move_slice_window_step_hack); } } @@ -136,20 +134,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1 make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); using ThreadwiseTransfer = - ThreadwiseDynamicTensorSliceTransfer_v3r1; + ThreadwiseTensorSliceTransfer_v3r1; ThreadwiseTransfer threadwise_transfer_; }; diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp similarity index 89% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp rename to composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp index 6d48a18169..366451dcc3 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_dlops_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp @@ -1,14 +1,14 @@ -#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP -#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP +#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP +#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_dlops_v2r3.hpp" -#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" +#include "blockwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include 
"threadwise_tensor_slice_set.hpp" namespace ck { @@ -25,7 +25,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_contraction_dlops_v1r2( + kernel_contraction_dlops_v1r2( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -84,12 +84,12 @@ template -struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> +struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -110,17 +110,15 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0, I1, Number{}, GK1), - max_lds_align); + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0, I1, Number{}, GK1), - max_lds_align); + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( @@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 const auto GM11 = 
Number{}; const auto GM10 = GM1 / GM11; - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor( + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor( a_grid_desc_gk0_gm0_gm1_gk1, make_tuple(make_pass_through_transform(GK0), make_pass_through_transform(GM0), @@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 const auto GN11 = Number{}; const auto GN10 = GN1 / GN11; - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor( + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor( b_grid_desc_gk0_gn0_gn1_gk1, make_tuple(make_pass_through_transform(GK0), make_pass_through_transform(GN0), @@ -259,7 +257,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 constexpr auto BM0 = BM / BM1; constexpr auto BN0 = BN / BN1; - const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( + const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor( c_grid_desc_gm0_gm1_gn0_gn1, make_tuple(make_pass_through_transform(GM0), make_unmerge_transform(make_tuple(GM10, GM11)), @@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( + const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor( c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, make_tuple(make_pass_through_transform(GM10), make_merge_transform(make_tuple(GM0, GM11)), @@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const 
auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor( + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor( c_gm10_bm_gn10_bn_grid_desc, make_tuple(make_pass_through_transform(GM10), make_unmerge_transform(make_tuple(BM0, BM1)), @@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0, I1, Number{}, GK1), - max_lds_align); + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0, I1, Number{}, GK1), - max_lds_align); + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); // A matrix in LDS memory for blockwise GEMM // be careful of LDS alignment - constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); // B matrix in LDS memory for blockwise GEMM // be careful of LDS alignment - constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == @@ -385,7 +381,7 @@ struct 
GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 "wrong!"); // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, @@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 make_multi_index(0, 0, 0, 0, 0)); // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, @@ -457,9 +453,8 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = - make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); + constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( @@ -475,9 +470,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 auto c_thread_buf = make_static_buffer( c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); - ThreadwiseDynamicTensorSliceSet_v1{} + ThreadwiseTensorSliceSet_v1{} .Run(c_thread_desc_bm0_bm1_bn0_bn1, make_tuple(I0, I0, I0, I0), c_thread_buf, @@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 // LDS double buffer: preload data into LDS { a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); 
b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); @@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 // even iteration a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, @@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 // odd iteration a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, 
AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 { a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 // output: register to global memory { constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = - make_dynamic_naive_tensor_descriptor_packed_v2( + make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, Number{}, @@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id()); - ThreadwiseDynamicTensorSliceTransfer_v1r3< + ThreadwiseTensorSliceTransfer_v1r3< FloatAcc, FloatC, decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), @@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0 c_thread_buf, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, 
c_grid_buf, - CGridIteratorHacks{}); + CGridStepHacks{}); } } }; diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp similarity index 74% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp rename to composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp index 7a4ef1d7ea..31a0fa342a 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp @@ -1,14 +1,14 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP +#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP +#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_dlops_v2r2.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" namespace ck { @@ -26,7 +26,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_dlops_v1r2( + kernel_gemm_dlops_v1r2( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -68,28 +68,27 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_dlops_v1r2( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* 
__restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k_m0_m1_grid_desc, - const void CONSTANT* p_b_k_n0_n1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) + kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* p_b_k_n0_n1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) { // first cast void CONSTANT void* to void* // second cast void* to Desc* // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k_m0_m1_grid_desc = - *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); - const auto b_k_n0_n1_grid_desc = - *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); + const auto a_k_m0_m1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc)); + const auto b_k_n0_n1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc)); const auto c_m0_m10_m11_n0_n10_n11_grid_desc = *reinterpret_cast( - (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); const auto c_blockid_to_m0_n0_block_cluster_adaptor = *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -146,12 +145,12 @@ template -struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> +struct 
GridwiseGemmDlops_km_kn_mn_v1r2 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); // LDS allocation for A and B: be careful of alignment @@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 const auto M1 = Number{}; const auto M0 = M / M1; - const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor( + const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor( a_k_m_grid_desc, make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 const auto N1 = Number{}; const auto N0 = N / N1; - const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor( + const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor( b_k_n_grid_desc, make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 constexpr auto M10 = M1 / M11; constexpr auto N10 = N1 / N11; - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( c_m_n_grid_desc, make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), 
make_unmerge_transform(make_tuple(N0, N10, N11))), @@ -352,75 +351,75 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, I1, Number{}), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, I1, Number{}), max_lds_align); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K_M0_M1, - ABlockTransferThreadClusterLengths_K_M0_M1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k_m0_m1_grid_desc), - decltype(a_k_m0_m1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_M1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_k_m0_m1_grid_desc, - make_multi_index(0, im0, 0), - a_k_m0_m1_block_desc, - make_multi_index(0, 0, 0)); + BlockwiseTensorSliceTransfer_v4, + 
ABlockTransferThreadSliceLengths_K_M0_M1, + ABlockTransferThreadClusterLengths_K_M0_M1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k_m0_m1_grid_desc), + decltype(a_k_m0_m1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_M1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k_m0_m1_grid_desc, + make_multi_index(0, im0, 0), + a_k_m0_m1_block_desc, + make_multi_index(0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K_N0_N1, - BBlockTransferThreadClusterLengths_K_N0_N1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k_n0_n1_grid_desc), - decltype(b_k_n0_n1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_N1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_k_n0_n1_grid_desc, - make_multi_index(0, in0, 0), - b_k_n0_n1_block_desc, - make_multi_index(0, 0, 0)); + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K_N0_N1, + BBlockTransferThreadClusterLengths_K_N0_N1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k_n0_n1_grid_desc), + decltype(b_k_n0_n1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_N1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k_n0_n1_grid_desc, + make_multi_index(0, in0, 0), + b_k_n0_n1_block_desc, + make_multi_index(0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -447,9 +446,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = 
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); - constexpr auto c_m10_m11_n10_n11_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = @@ -465,9 +463,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 auto c_thread_buf = make_static_buffer( c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); - ThreadwiseDynamicTensorSliceSet_v1{} + ThreadwiseTensorSliceSet_v1{} .Run(c_m10_m11_n10_n11_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, @@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; - constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; + constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{}; + constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{}; // hack to control index calculation when move slice window for A and B matrix for // threadwise copy - constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = - AGridMoveSliceWindowIteratorHacks{}; - constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = - BGridMoveSliceWindowIteratorHacks{}; + constexpr auto a_k_m0_m1_global_move_slice_window_step_hack = + AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = + BGridMoveSliceWindowStepHacks{}; auto a_block_even_buf = make_dynamic_buffer( p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); @@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 
// LDS double buffer: preload data into LDS { a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); @@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 do { // even iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, @@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); // odd iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_k_n0_n1_grid_desc, 
- b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); __syncthreads(); // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 { a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); + a_k_m0_m1_global_move_slice_window_step_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); + b_k_n0_n1_global_move_slice_window_step_hack); __syncthreads(); // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 // output: register to global memory { - constexpr index_t M11 = - M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; - constexpr index_t N11 = - N1PerThreadN111 * M11N11ThreadClusterN1100 * 
M11N11ThreadClusterN1101; - - constexpr index_t M10 = MPerBlockM1 / M11; - constexpr index_t N10 = NPerBlockN1 / N11; - - constexpr index_t M111 = M1PerThreadM111; - constexpr index_t N111 = N1PerThreadN111; - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( + make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, Number{}, @@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); - ThreadwiseDynamicTensorSliceTransfer_v1r3< + ThreadwiseTensorSliceTransfer_v1r3< FloatAcc, FloatC, decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), @@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2 c_thread_buf, c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_buf, - CGridIteratorHacks{}); + CGridStepHacks{}); } } }; diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp similarity index 82% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp rename to composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp index db3cb99121..1017dcc2a1 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v1r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp @@ -1,14 +1,14 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP +#ifndef CK_GRIDWISE_GEMM_V1R3_HPP +#define CK_GRIDWISE_GEMM_V1R3_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_dlops_v2r3.hpp" -#include 
"blockwise_dynamic_tensor_slice_transfer_v2.hpp" -#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" +#include "blockwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_set.hpp" namespace ck { @@ -26,7 +26,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_dlops_v1r3( + kernel_gemm_dlops_v1r3( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -68,28 +68,27 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_dlops_v1r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, - const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) + kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, + const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) { // first cast void CONSTANT void* to void* // second cast void* to Desc* // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k0_m0_m1_k1_grid_desc = - *reinterpret_cast((const void*)p_a_k0_m0_m1_k1_grid_desc); - const auto b_k0_n0_n1_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n0_n1_k1_grid_desc); + const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc)); + const auto 
b_k0_n0_n1_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc)); const auto c_m0_m10_m11_n0_n10_n11_grid_desc = *reinterpret_cast( - (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); const auto c_blockid_to_m0_n0_block_cluster_adaptor = *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -142,12 +141,12 @@ template -struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> +struct GridwiseGemmDlops_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment @@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 const auto M = a_k0_m_k1_grid_desc.GetLength(I1); const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2); // TODO: also check validity of all components 
(blockwise-copy, threadwise-copy, etc) return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && K0 == b_k0_n_k1_grid_desc.GetLength(I0) && + K1 == a_k0_m_k1_grid_desc.GetLength(I2) && K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); } @@ -231,13 +230,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 const auto M1 = Number{}; const auto M0 = M / M1; - const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor( - a_k0_m_k1_grid_desc, - make_tuple(make_pass_through_transform(K0), - make_unmerge_transform(make_tuple(M0, M1)), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + const auto a_k0_m0_m1_k1_grid_desc = + transform_tensor_descriptor(a_k0_m_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); return a_k0_m0_m1_k1_grid_desc; } @@ -251,13 +250,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 const auto N1 = Number{}; const auto N0 = N / N1; - const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor( - b_k0_n_k1_grid_desc, - make_tuple(make_pass_through_transform(K0), - make_unmerge_transform(make_tuple(N0, N1)), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + const auto b_k0_n0_n1_k1_grid_desc = + transform_tensor_descriptor(b_k0_n_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); return 
b_k0_n0_n1_k1_grid_desc; } @@ -284,7 +283,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 constexpr auto M10 = M1 / M11; constexpr auto N10 = N1 / N11; - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( c_m_n_grid_desc, make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_unmerge_transform(make_tuple(N0, N10, N11))), @@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // A matrix in LDS memory, for blockwise GEMM - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, for blockwise GEMM - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == @@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 "wrong!"); // A matrix blockwise copy - auto a_blockwise_copy = 
BlockwiseDynamicTensorSliceTransfer_v4r1< + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, @@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 make_multi_index(0, 0, 0, 0)); // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< BlockSize, InMemoryDataOperationEnum_t::Set, Sequence, @@ -453,9 +452,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - constexpr auto c_m10_m11_n10_n11_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( @@ -471,9 +469,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 auto c_thread_buf = make_static_buffer( c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); - ThreadwiseDynamicTensorSliceSet_v1{} + ThreadwiseTensorSliceSet_v1{} .Run(c_m10_m11_n10_n11_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, @@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // LDS double buffer: preload data into LDS { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); 
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); @@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // even iteration a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead( - b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, @@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // odd iteration a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); + AGridMoveSliceWindowStepHacks{}); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead( - b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on current data blockwise_gemm.Run( @@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // LDS double buffer: tail if 
constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowIteratorHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowIteratorHacks{}); + a_blockwise_copy.MoveSrcSliceWindow( + a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow( + b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); __syncthreads(); // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( @@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 // output: register to global memory { - constexpr auto M11 = - Number{}; - constexpr auto N11 = - Number{}; - - constexpr index_t M10 = MPerBlockM1 / M11; - constexpr index_t N10 = NPerBlockN1 / N11; - - constexpr index_t M111 = M1PerThreadM111; - constexpr index_t N111 = N1PerThreadN111; - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( + make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, Number{}, @@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id()); - ThreadwiseDynamicTensorSliceTransfer_v1r3< + ThreadwiseTensorSliceTransfer_v1r3< FloatAcc, FloatC, decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), @@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3 c_thread_buf, 
c_m0_m10_m11_n0_n10_n11_grid_desc, c_grid_buf, - CGridIteratorHacks{}); + CGridStepHacks{}); } } }; diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp similarity index 74% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp rename to composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp index 34dea34833..b141307b77 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_dlops_v2.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp @@ -1,12 +1,12 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP +#ifndef CK_GRIDWISE_GEMM_V2_HPP +#define CK_GRIDWISE_GEMM_V2_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" #include "blockwise_gemm_dlops_v3.hpp" namespace ck { @@ -42,12 +42,12 @@ template -struct GridwiseDynamicGemmDlops_km_kn_mn_v3 + typename AGlobalStepHacks, + typename BGlobalStepHacks, + typename CGlobalStepHacks, + typename AGlobalMoveSliceWindowStepHacks, + typename BGlobalMoveSliceWindowStepHacks> +struct GridwiseGemmDlops_km_kn_mn_v3 { __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_e_k_desc = 
make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); // LDS allocation for A and B: be careful of alignment @@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 // divide block work by [M, N] #if 0 - const auto k_block_work_num = K / Number{}; const auto ho_block_work_num = Ho / Number{}; const auto wo_block_work_num = Wo / Number{}; const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; @@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; #else // Hack: this force result into SGPR - const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock); const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; @@ -134,23 +132,21 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); - constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_e_n_ho_wo_block_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); + constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); // c_thread_mtx definition: this is a mess // TODO:: more elegent way of defining c_thread_mtx - 
constexpr auto c_k_n_ho_wo_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); + constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); auto blockwise_gemm = BlockwiseGemmDlops_km_kn_m0m1n0n1_v3, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_e_k_global_desc), - decltype(a_e_k_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1>, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_e_k_global_desc, - make_multi_index(0, k_block_data_on_global), - a_e_k_desc, - make_multi_index(0, 0)); + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e_k_global_desc), + decltype(a_e_k_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1>, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_e_k_global_desc, + make_multi_index(0, k_block_data_on_global), + a_e_k_desc, + make_multi_index(0, 0)); - constexpr auto b_e_n_ho_wo_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); + constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); - auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2< - FloatAB, - FloatAB, - decltype(b_e_n_ho_wo_global_desc), - decltype(b_e_n_ho_wo_thread_desc), - Sequence, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 
BBlockTransferSrcScalarPerVector, - 1, - true>(b_e_n_ho_wo_global_desc, - make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); + auto b_threadwise_transfer = + ThreadwiseTensorSliceTransfer_v2, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + 1, + true>( + b_e_n_ho_wo_global_desc, + make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); auto a_block_buf = make_dynamic_buffer( p_shared_block, a_e_k_desc.GetElementSpaceSize()); @@ -232,44 +227,45 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 // register allocation for output StaticBuffer + c_k_n_ho_wo_thread_desc.GetElementSpaceSize(), + true> c_thread_buf; // initialize output thread tensor - ThreadwiseDynamicTensorSliceSet_v1>{} + ThreadwiseTensorSliceSet_v1>{} .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; - constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; + constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; // hack to control index calculation when move slice window for A and B matrix for // threadwise copy - constexpr auto a_e_k_global_move_slice_window_iterator_hack = - AGlobalMoveSliceWindowIteratorHacks{}; - constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = - BGlobalMoveSliceWindowIteratorHacks{}; + constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{}; + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = + BGlobalMoveSliceWindowStepHacks{}; // double regsiter buffer for b StaticBuffer + b_e_n_ho_wo_thread_desc.GetElementSpaceSize(), + true> b_thread_even_buf, 
b_thread_odd_buf; // LDS double buffer: preload data { - a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks); + a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks); b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, b_global_buf, b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), b_thread_even_buf, - b_e_n_ho_wo_global_iterator_hacks); + b_e_n_ho_wo_global_step_hacks); a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); } @@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), b_thread_odd_buf, - b_e_n_ho_wo_global_iterator_hacks); + b_e_n_ho_wo_global_step_hacks); // LDS double buffer: GEMM on current data // TODO: @Zhang Jing: blockwise gemm should be able to move slice window @@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), b_thread_even_buf, - b_e_n_ho_wo_global_iterator_hacks); + b_e_n_ho_wo_global_step_hacks); // LDS double buffer: GEMM on current data blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); @@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 b_e_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), b_thread_odd_buf, - b_e_n_ho_wo_global_iterator_hacks); + b_e_n_ho_wo_global_step_hacks); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); @@ -351,23 +347,22 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 // output: register to global memory { // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor - constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{}; const index_t k_thread_data_on_global = k_block_data_on_global + k_thread_id * KPerThread; - ThreadwiseDynamicTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_k_n_ho_wo_thread_desc), - 
decltype(c_k_n_ho_wo_global_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>( + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>( c_k_n_ho_wo_global_desc, make_multi_index( k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) @@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3 c_thread_buf, c_k_n_ho_wo_global_desc, c_global_buf, - c_k_n_ho_wo_global_tensor_iterator_hacks); + c_k_n_ho_wo_global_tensor_step_hacks); } } diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp similarity index 74% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp rename to composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index a5b1de79a7..dcb16e5dcd 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -1,14 +1,14 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP -#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP #include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" +#include 
"blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" namespace ck { @@ -24,13 +24,13 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0MK1GridDesc a_k0_m_k1_grid_desc, - const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0MK1GridDesc a_k0_m_k1_grid_desc, + const BK0NK1GridDesc b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -58,25 +58,25 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - const auto a_k0_m_k1_grid_desc = - 
*reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_block_cluster_adaptor = - *reinterpret_cast((const void*)p_c_block_cluster_adaptor); + const auto a_k0_m_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); + const auto b_k0_n_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); + const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); + const auto c_block_cluster_adaptor = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); __shared__ FloatAB p_shared_block[shared_block_size]; @@ -126,13 +126,13 @@ template -struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -148,12 +148,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // LDS allocation for A and B: be careful of alignment @@ -203,9 +203,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 __host__ __device__ static constexpr auto MakeCM0M1M2NGridDescriptor(const 
CMNGridDesc& c_m_n_grid_desc) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - constexpr auto xdlops_gemm = XdlopsGemm{}; constexpr auto CLayout = xdlops_gemm.GetCLayout(); @@ -217,10 +214,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); - constexpr auto N0 = Number{}; constexpr auto N1 = Number{}; - const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor( + const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( c_m_n_grid_desc, make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), @@ -269,11 +265,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, const CBlockClusterAdaptor& c_block_cluster_adaptor) { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -282,8 +273,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); // divide block work by [M, N] const auto block_work_idx = @@ -301,67 +290,65 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // B matrix in LDS memory, dst of 
blockwise copy // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2( make_tuple(Number{}, Number{}, K1), max_lds_align); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K0_M_K1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k0_m_k1_grid_desc), - decltype(a_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_k0_m_k1_grid_desc, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_k0_m_k1_block_desc, - make_multi_index(0, 0, 0)); + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m_k1_grid_desc), + decltype(a_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k0_m_k1_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_k0_m_k1_block_desc, + make_multi_index(0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K0_N_K1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k0_n_k1_grid_desc), - decltype(b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - 
BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_k0_n_k1_grid_desc, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_k0_n_k1_block_desc, - make_multi_index(0, 0, 0)); + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n_k1_grid_desc), + decltype(b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k0_n_k1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -375,7 +362,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 NPerBlock % (NPerWave * NRepeat) == 0, "wrong!"); - constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( + constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor( a_k0_m_k1_block_desc, make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( @@ -384,7 +371,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( + constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor( b_k0_n_k1_block_desc, make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( @@ -410,21 +397,19 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); - constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{})); + constexpr auto 
c_mr_nr_blk_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); StaticBuffer, - c_mr_nr_blk_desc.GetElementSpaceSize()> + c_mr_nr_blk_desc.GetElementSpaceSize(), + true> c_thread_buf; // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); - FloatAB* p_a_block = p_shared_block; FloatAB* p_b_block = p_shared_block + a_block_space_size; @@ -432,15 +417,13 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; - constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; + constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; + constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; // hack to control index calculation when move slice window for A and B matrix for // threadwise copy - constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = - AGridMoveSliceWindowIteratorHacks{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = - BGridMoveSliceWindowIteratorHacks{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; auto a_block_buf = make_dynamic_buffer( p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); @@ -449,10 +432,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // preload data into LDS { - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, 
b_k0_n_k1_grid_iterator_hacks); + a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); @@ -465,18 +446,16 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_iterator_hack); + a_k0_m_k1_grid_move_slice_window_step_hack); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_iterator_hack); + b_k0_n_k1_grid_move_slice_window_step_hack); - a_blockwise_copy.RunRead( - a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); block_sync_lds(); - b_blockwise_copy.RunRead( - b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); @@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr index_t N1 = CLayout.N0(); constexpr auto c_m0_m1_m2_n_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{}, Number<1>{}, Number<1>{}, @@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Number{}, Number<1>{})); - StaticBuffer + StaticBuffer c_blk_buf_; static_for<0, MRepeat, 1>{}([&](auto mr_i) { @@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + constexpr auto 
c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); - ThreadwiseDynamicTensorSliceTransfer_v1r3< + ThreadwiseTensorSliceTransfer_v1r3< FloatC, FloatC, decltype(c_m0_m1_m2_n_thread_desc), @@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_blk_buf_, c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); } #else { @@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr index_t M1 = CLayout.N1(); constexpr index_t M2 = CLayout.M0(); - constexpr auto c_m0_m1_m2_n_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); - - StaticBuffer c_blk_buf_; + constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; auto c_thread_copy = - ThreadwiseDynamicTensorSliceTransfer_v1r3, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ c_m0_m1_m2_n_grid_desc, make_multi_index(0, 0, @@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 
c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); return c_thread_idx_; }; @@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; auto nrepeat_plus_copy = [&](auto c_thread_idx_) { @@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; auto mrepeat_minus_copy = [&](auto c_thread_idx_) { @@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; auto nrepeat_minus_copy = [&](auto c_thread_idx_) { @@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 c_thread_buf[Number{}].template AsType(), c_m0_m1_m2_n_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_iterator_hacks); + c_m0_m1_m2_n_grid_tensor_step_hacks); }; static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp index 7e7bb9c8c3..a925a5cd68 100644 --- a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -21,10 +21,10 @@ template ::type = false> + typename enable_if::type = false> struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 { __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() @@ -97,10 +97,9 @@ struct 
ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); - amd_inner_product_dlop( - a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); }); }); }); @@ -124,10 +123,10 @@ template ::type = false> + typename enable_if::type = false> struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() @@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); - amd_inner_product_dlop( + inner_product( a_vec.template AsType()[I0], b_vec.template AsType()[I0], c_buf(Number{})); diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp index 153d512df7..015ad675fb 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -19,9 +19,9 @@ template ::type = false> + typename enable_if::type = false> struct ThreadwiseGemmDlops_km_kn_mn_v3 { template {}; constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; constexpr auto E = ADesc{}.GetLength(I0); constexpr auto K = ADesc{}.GetLength(I1); diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp similarity index 78% rename from composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp rename to composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp index f1b632aa84..0c7aa978a7 100644 --- 
a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp @@ -1,9 +1,9 @@ -#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP -#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP +#ifndef CK_THREADWISE_TENSOR_SET_HPP +#define CK_THREADWISE_TENSOR_SET_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -11,12 +11,12 @@ namespace ck { // 1. Desc is known at compile-time // 2. Buffer is StaticBuffer // 3. OriginIdx is known at compile-time -// 4. use #-iterator +// 4. use #-step template ::type = false> -struct ThreadwiseDynamicTensorSliceSet_v1 + typename enable_if::type = false> +struct ThreadwiseTensorSliceSet_v1 { static constexpr index_t nDim = SliceLengths::Size(); @@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1 constexpr auto origin_idx = to_multi_index(OriginIdx{}); static_ford{}([&](auto access_idx) { - constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx); + constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx); constexpr bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp similarity index 80% rename from composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp rename to composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 9626113686..0071accf7f 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -1,9 +1,9 @@ -#ifndef 
CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP -#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -57,20 +57,20 @@ template ::type = false> -struct ThreadwiseDynamicTensorSliceTransfer_v1r3 + typename enable_if::type = false> +struct ThreadwiseTensorSliceTransfer_v1r3 { static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; - using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3( - const DstDesc& dst_desc, const Index& dst_slice_origin_idx) - : dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, + const Index& dst_slice_origin_idx) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)) { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc need to known at compile-time"); @@ -78,19 +78,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) { - dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } template + typename DstStepHacks> __device__ void Run(const SrcDesc&, const SrcSliceOriginIdx&, const SrcBuffer& src_buf, const DstDesc& dst_desc, DstBuffer& dst_buf, - const DstIteratorHacks& dst_iterator_hacks) + const DstStepHacks& dst_step_hacks) { static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); @@ -127,31 +127,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 constexpr auto ordered_access_lengths = container_reorder_given_new2old(access_lengths, dim_access_order); - // make forward iterators - const auto dst_forward_iterators = generate_tuple( + // make forward steps + const auto dst_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - dst_desc, forward_step, dst_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto dst_backward_iterators = generate_tuple( + // make backward steps + const auto dst_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - dst_desc, backward_step, dst_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); }, Number{}); @@ -235,13 +235,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 { if constexpr(forward_sweep[i]) { - move_dynamic_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]); + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); } else { - move_dynamic_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]); + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); } } }); @@ -250,10 +250,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 // move dst coordinate back to slice origin (or not) if constexpr(DstResetCoordinateAfterRun) { - const auto dst_reset_iterator = - make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); } } @@ -268,11 +268,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto dst_iterator_hacks = + constexpr auto dst_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks); + Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks); } __device__ static constexpr auto GetDstCoordinateResetStep() @@ -345,10 +345,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // is it OK to 
construct a new step every time? - const auto adjusted_step = - make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } private: @@ -374,20 +373,20 @@ template ::type = false> -struct ThreadwiseDynamicTensorSliceTransfer_v2 + typename enable_if::type = false> +struct ThreadwiseTensorSliceTransfer_v2 { static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; - using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc, - const Index& src_slice_origin_idx) - : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, + const Index& src_slice_origin_idx) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) { static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc need to known at compile-time"); @@ -395,19 +394,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) { - src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); } template + typename SrcStepHacks> __device__ void Run(const SrcDesc& src_desc, const SrcBuffer& src_buf, const DstDesc&, const DstSliceOriginIdx&, DstBuffer& dst_buf, - const SrcIteratorHacks& src_iterator_hacks) + const SrcStepHacks& src_step_hacks) { static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); @@ -442,31 +441,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 constexpr auto ordered_access_lengths = container_reorder_given_new2old(access_lengths, dim_access_order); - // make forward iterators - const auto src_forward_iterators = generate_tuple( + // make forward steps + const auto src_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - src_desc, forward_step, src_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto src_backward_iterators = generate_tuple( + // make backward steps + const auto src_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? 
-src_scalar_per_access[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - src_desc, backward_step, src_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); }, Number{}); @@ -548,13 +547,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 { if constexpr(forward_sweep[i]) { - move_dynamic_tensor_coordinate( - src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]); + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[dim_access_order[i]]); } else { - move_dynamic_tensor_coordinate( - src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]); + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[dim_access_order[i]]); } } }); @@ -563,10 +562,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { - const auto src_reset_iterator = - make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); } } @@ -581,11 +580,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto src_iterator_hacks = + constexpr auto src_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks); + Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -658,10 +657,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new 
step every time? - const auto adjusted_step = - make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } private: @@ -693,23 +691,23 @@ template // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation -struct ThreadwiseDynamicTensorSliceTransfer_v3 +struct ThreadwiseTensorSliceTransfer_v3 { static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; - using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); - using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc, - const Index& src_slice_origin, - const DstDesc& dst_desc, - const Index& dst_slice_origin) - : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) { 
// TODO: fix this static_assert(is_same::value, @@ -718,18 +716,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) { - src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); } __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) { - dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -757,31 +754,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward iterators - const auto src_forward_iterators = generate_tuple( + // make forward steps + const auto src_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? 
src_scalar_per_access[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - src_desc, forward_step, src_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto src_backward_iterators = generate_tuple( + // make backward steps + const auto src_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - src_desc, backward_step, src_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); }, Number{}); @@ -862,13 +859,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 { if constexpr(forward_sweep[i]) { - move_dynamic_tensor_coordinate( - src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); } else { - move_dynamic_tensor_coordinate( - src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); } } }); @@ -877,17 +874,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { - const auto src_reset_iterator = - make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); } } - template - __device__ void RunWrite(const DstDesc& dst_desc, - 
DstBuffer& dst_buf, - const DstIteratorHacks& dst_iterator_hacks) + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -915,35 +911,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward iterators - const auto dst_forward_iterators = generate_tuple( + // make forward steps + const auto dst_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; }); - const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( - dst_desc, forward_step, dst_iterator_hacks[I0][i]); - - return forward_iterator; + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto dst_backward_iterators = generate_tuple( + // make backward steps + const auto dst_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; }); - const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( - dst_desc, backward_step, dst_iterator_hacks[I1][i]); - - return backward_iterator; + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); }, Number{}); @@ -1026,13 +1018,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 { if constexpr(forward_sweep[i]) { - move_dynamic_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); } else { - move_dynamic_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); } } }); @@ -1041,10 +1033,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // move dst coordinate back to slice origin (or not) if constexpr(DstResetCoordinateAfterRun) { - const auto dst_reset_iterator = - make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); } } @@ -1055,11 +1047,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto src_iterator_hacks = + constexpr auto src_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunRead(src_desc, src_buf, src_iterator_hacks); + RunRead(src_desc, src_buf, src_step_hacks); } template @@ -1069,11 +1061,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto dst_iterator_hacks = + constexpr auto dst_step_hacks = 
make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + RunWrite(dst_desc, dst_buf, dst_step_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -1206,18 +1198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = - make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { // if src coord was not reset by RunRead(), then need to adjust the step here const auto adjusted_step_idx = @@ -1225,10 +1216,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? 
- const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( - src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, @@ -1240,19 +1231,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = - make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } private: static constexpr auto buffer_desc_ = - make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; @@ -1264,37 +1254,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3 // 2. SrcBuffer is DynamicBuffer // 3. src_ref_idx is known at run-time // 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-iterator +// 5. use #-step // 2. dst: // 1. DstDesc is known at compile-time // 2. DstBuffer is StaticBuffer // 3. DstOriginIdx is known at compile-time // 4. use direct address calculation // 3. 
vector access on src -template < - typename SrcData, - typename DstData, - typename SrcDesc, - typename DstDesc, - typename SliceLengths, - typename DimAccessOrder, - index_t SrcVectorDim, - index_t SrcScalarPerVector, - index_t SrcScalarStrideInVector, - typename std::enable_if::type = false> -struct ThreadwiseDynamicTensorSliceTransfer_v4 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4 { static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; - using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx) - : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) { static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc and DstDesc need to known at compile-time"); @@ -1390,13 +1379,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 constexpr auto src_ref_to_data_disp_idx = src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - constexpr auto src_ref_to_data_disp_coord_iterator = - make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); auto src_data_coord = src_ref_coord_; - move_dynamic_tensor_coordinate( - src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); vector_type_maker_t src_tmp_vector; @@ -1435,10 +1423,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 { constexpr auto src_desc = SrcDesc{}; - const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( - src_desc, to_multi_index(src_slice_move_step_idx)); + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); - move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); } private: diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp similarity index 77% rename from composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp rename to composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp index ba60e26c38..f069540343 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -1,9 +1,9 @@ -#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP -#define 
CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" namespace ck { @@ -30,7 +30,7 @@ template // control whether to move back dst coordinate after each // RunWrite(), will be fused with MoveDstSliceWindow to // save addr computation -struct ThreadwiseDynamicTensorSliceTransfer_v3r1 +struct ThreadwiseTensorSliceTransfer_v3r1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -38,18 +38,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; - using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); - using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc, - const Index& src_slice_origin, - const DstDesc& dst_desc, - const Index& dst_slice_origin) - : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + 
const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) { // TODO: fix this static_assert(is_same::value, @@ -64,18 +64,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) { - src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); } __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) { - dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const SrcIteratorHacks& src_iterator_hacks) + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) { static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -96,9 +95,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 I1), SrcVectorTensorContiguousDimOrder{}); - constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( - sequence_to_tuple_of_number(src_vector_tensor_lengths), - sequence_to_tuple_of_number(src_vector_tensor_strides)); + constexpr auto src_vector_desc = + make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); // access order and lengths constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; @@ -108,31 +107,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward iterators - const 
auto src_forward_iterators = generate_tuple( + // make forward steps + const auto src_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - src_desc, forward_step, src_iterator_hacks[I0][i]); + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto src_backward_iterators = generate_tuple( + // make backward steps + const auto src_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; }); - return make_dynamic_tensor_coordinate_iterator( - src_desc, backward_step, src_iterator_hacks[I1][i]); + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); }, Number{}); @@ -219,13 +218,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 { if constexpr(forward_sweep[i]) { - move_dynamic_tensor_coordinate( - src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); } else { - move_dynamic_tensor_coordinate( - src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); } } }); @@ -234,17 +233,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) { - const auto src_reset_iterator = - 
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); } } - template - __device__ void RunWrite(const DstDesc& dst_desc, - DstBuffer& dst_buf, - const DstIteratorHacks& dst_iterator_hacks) + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) { static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, @@ -265,9 +263,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 I1), DstVectorTensorContiguousDimOrder{}); - constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2( - sequence_to_tuple_of_number(dst_vector_tensor_lengths), - sequence_to_tuple_of_number(dst_vector_tensor_strides)); + constexpr auto dst_vector_desc = + make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(dst_vector_tensor_lengths), + sequence_to_tuple_of_number(dst_vector_tensor_strides)); // dst access order and lengths constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; @@ -277,35 +275,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward iterators - const auto dst_forward_iterators = generate_tuple( + // make forward steps + const auto dst_forward_steps = generate_tuple( [&](auto i) { - Index forward_step; + Index forward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; + forward_step_idx(j) = (i.value == j.value) ? 
dst_vector_tensor_lengths[i] : 0; }); - const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( - dst_desc, forward_step, dst_iterator_hacks[I0][i]); - - return forward_iterator; + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); }, Number{}); - // make backward iterators - const auto dst_backward_iterators = generate_tuple( + // make backward steps + const auto dst_backward_steps = generate_tuple( [&](auto i) { - Index backward_step; + Index backward_step_idx; static_for<0, nDim, 1>{}([&](auto j) { - backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; + backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; }); - const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( - dst_desc, backward_step, dst_iterator_hacks[I1][i]); - - return backward_iterator; + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); }, Number{}); @@ -394,13 +388,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 { if constexpr(forward_sweep[i]) { - move_dynamic_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); } else { - move_dynamic_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); } } }); @@ -409,10 +403,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 // move dst coordinate back to slice origin (or not) if constexpr(DstResetCoordinateAfterRun) { - const auto dst_reset_iterator = - make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + 
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); } } @@ -423,11 +417,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto src_iterator_hacks = + constexpr auto src_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunRead(src_desc, src_buf, src_iterator_hacks); + RunRead(src_desc, src_buf, src_step_hacks); } template @@ -437,11 +431,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 constexpr auto zeros = typename uniform_sequence_gen::type{}; - constexpr auto dst_iterator_hacks = + constexpr auto dst_step_hacks = make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), generate_tuple([&](auto) { return zeros; }, Number{})); - RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + RunWrite(dst_desc, dst_buf, dst_step_hacks); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -564,18 +558,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? 
- const auto adjusted_step = - make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason - template + template __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx, - const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) { // if src coord was not reset by RunRead(), then need to adjust the step here const auto adjusted_step_idx = @@ -583,10 +576,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); // is it OK to construct a new step every time? - const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( - src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); - move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, @@ -598,19 +591,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); // is it OK to construct a new step every time? 
- const auto adjusted_step = - make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); } private: static constexpr auto buffer_desc_ = - make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); - StaticBuffer buffer_; + StaticBuffer buffer_; SrcCoord src_coord_; DstCoord dst_coord_; @@ -622,25 +614,24 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1 // 2. SrcBuffer is DynamicBuffer // 3. src_ref_idx is known at run-time // 4. SrcRefToOriginDisplacement is known at compile-time -// 5. use #-iterator +// 5. use #-step // 2. dst: // 1. DstDesc is known at compile-time // 2. DstBuffer is StaticBuffer // 3. DstOriginIdx is known at compile-time // 4. use direct address calculation // 3. 
vector access on src -template < - typename SrcData, - typename DstData, - typename SrcDesc, - typename DstDesc, - typename SliceLengths, - typename DimAccessOrder, - typename SrcVectorTensorLengths, - typename SrcVectorTensorContiguousDimOrder, - typename std::enable_if::type = false> -struct ThreadwiseDynamicTensorSliceTransfer_v4r1 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4r1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -649,12 +640,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 using Index = MultiIndex; - using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx) - : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) + __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) { static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc and DstDesc need to known at compile-time"); @@ -712,9 +703,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 I1), SrcVectorTensorContiguousDimOrder{}); - constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( - sequence_to_tuple_of_number(src_vector_tensor_lengths), - sequence_to_tuple_of_number(src_vector_tensor_strides)); + constexpr auto src_vector_desc = + make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); // access order and lengths constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; @@ -734,13 +725,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 constexpr auto src_ref_to_data_disp_idx = src_ref_to_origin_disp_idx + data_to_origin_disp_idx; - constexpr auto src_ref_to_data_disp_coord_iterator = - make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); auto src_data_coord = src_ref_coord_; - move_dynamic_tensor_coordinate( - src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); vector_type_maker_t src_vector; @@ -775,10 +765,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1 { constexpr auto src_desc = SrcDesc{}; - const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( - src_desc, to_multi_index(src_slice_move_step_idx)); + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); - move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); } private: diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp 
index 876a1174e7..affe096ace 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -350,8 +350,8 @@ struct mfma_info class FloatC> __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const { - const auto p_a = reinterpret_cast(a); - const auto p_b = reinterpret_cast(b); + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); return intrin_mfma_f32_32x32x2bf16::run( p_a, p_b, reg_c); @@ -384,8 +384,8 @@ struct mfma_info class FloatC> __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const { - const auto p_a = reinterpret_cast(a); - const auto p_b = reinterpret_cast(b); + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); } @@ -417,8 +417,8 @@ struct mfma_info class FloatC> __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const { - const auto p_a = reinterpret_cast(a); - const auto p_b = reinterpret_cast(b); + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); } @@ -450,8 +450,8 @@ struct mfma_info class FloatC> __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const { - const auto p_a = reinterpret_cast(a); - const auto p_b = reinterpret_cast(b); + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); } @@ -483,8 +483,8 @@ struct mfma_info class FloatC> __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const { - const auto p_a = reinterpret_cast(a); - const auto p_b = reinterpret_cast(b); + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); } diff --git 
a/composable_kernel/include/utility/amd_address_space.hpp b/composable_kernel/include/utility/amd_address_space.hpp new file mode 100644 index 0000000000..24c95b27af --- /dev/null +++ b/composable_kernel/include/utility/amd_address_space.hpp @@ -0,0 +1,44 @@ +#ifndef CK_AMD_ADDRESS_SPACE_HPP +#define CK_AMD_ADDRESS_SPACE_HPP + +#include "config.hpp" +#include "c_style_pointer_cast.hpp" + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +namespace ck { + +enum AddressSpaceEnum_t +{ + Generic, + Global, + Lds, + Sgpr, + Vgpr, +}; + +template +__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) +{ + // cast a pointer in "Constant" address space (4) to "Generic" address space (0) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +template +__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p) +{ + // cast a pointer in "Generic" address space (0) to "Constant" address space (4) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T CONSTANT*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp similarity index 87% rename from composable_kernel/include/utility/amd_buffer_addressing_v2.hpp rename to composable_kernel/include/utility/amd_buffer_addressing.hpp index 0139bceb61..57081b7fd7 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -1,34 +1,34 @@ -#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP -#define CK_AMD_BUFFER_ADDRESSING_V2_HPP +#ifndef 
CK_AMD_BUFFER_ADDRESSING_HPP +#define CK_AMD_BUFFER_ADDRESSING_HPP #include "data_type.hpp" namespace ck { template -union BufferResource_v2 +union BufferResource { // 128 bit SGPRs to supply buffer resource in buffer instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions - int32x4_t data; + int32x4_t content; StaticallyIndexedArray address; StaticallyIndexedArray range; StaticallyIndexedArray config; }; template -__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) { - BufferResource_v2 wave_buffer_resource; + BufferResource wave_buffer_resource; // wavewise base address (64 bit) wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); // wavewise range (32 bit) - wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T); + wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T); // wavewise setting (32 bit) wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; - return wave_buffer_resource.data; + return wave_buffer_resource.content; } // load @@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); template -__device__ typename vector_type::type -amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert( (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || @@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, } template -__device__ void amd_buffer_store_impl_v2(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t 
dst_thread_addr_offset, - index_t dst_wave_addr_offset) +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) { static_assert( (is_same::value && (N == 1 || N == 2 || N == 4)) || @@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type::type // buffer_load requires: // 1) p_src_wave must be in global memory space -// 2) p_src_wave to be a wavewise pointer. +// 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template __device__ typename vector_type_maker::type::type -amd_buffer_load_v2(const T* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_element_space) +amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) { const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space); + make_wave_buffer_resource(p_src_wave, src_element_space_size); - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T); + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; - using vector_t = typename vector_type_maker::type::type; - using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; + uint32_t src_addr_shift = src_thread_element_valid ? 
0 : 0x7fffffff; - return amd_buffer_load_impl_v2( + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else - vector_t tmp = amd_buffer_load_impl_v2( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_data_valid ? tmp : vector_t(0); + return src_thread_element_valid ? tmp : vector_t(0); #endif } +// buffer_load requires: +// 1) p_src_wave must be in global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_element_valid ? tmp : vector_t(customized_value); +} + // buffer_store requires: // 1) p_dst_wave must be global memory // 2) p_dst_wave to be a wavewise pointer. // It is user's responsibility to make sure that is true. 
template -__device__ void -amd_buffer_store_v2(const typename vector_type_maker::type::type src_thread_data, - T* p_dst_wave, - const index_t dst_thread_data_offset, - const bool dst_thread_data_valid, - const index_t dst_element_space) +__device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) { const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space); + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T); + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); using vector_t = typename vector_type_maker::type::type; using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x7fffffff; - amd_buffer_store_impl_v2( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else - if(dst_thread_data_valid) + if(dst_thread_element_valid) { - amd_buffer_store_impl_v2( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif diff --git a/composable_kernel/include/utility/amd_dlop.hpp b/composable_kernel/include/utility/amd_dlop.hpp deleted file mode 100644 index 8ce19012e9..0000000000 --- a/composable_kernel/include/utility/amd_dlop.hpp +++ /dev/null @@ -1,188 +0,0 @@ -#ifndef CK_AMD_DLOP_HPP -#define CK_AMD_DLOP_HPP - -#include "data_type.hpp" - -namespace ck { - -template -__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c); - -template <> -__device__ void -amd_inner_product_dlop(const float& a, const float& b, float& c) -{ -#if CK_USE_AMD_DLOP_INLINE_ASM - asm volatile("\n \ - v_fmac_f32 %0, %1, %2 \n \ - " - : "=v"(c) - : "v"(a), "v"(b), "0"(c)); -#else - c += a * b; -#endif -} - -template <> -__device__ void -amd_inner_product_dlop(const float2_t& a, const float2_t& b, float& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - amd_inner_product_dlop(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); -} - -template <> -__device__ void -amd_inner_product_dlop(const float4_t& a, const float4_t& b, float& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - amd_inner_product_dlop(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I2], - vector_type{b}.AsType()[I2], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I3], - 
vector_type{b}.AsType()[I3], - c); -} - -#if CK_USE_AMD_DLOP -template <> -__device__ void -amd_inner_product_dlop(const half2_t& a, const half2_t& b, float& c) -{ -#if CK_USE_AMD_DLOP_INLINE_ASM - asm volatile("\n \ - v_dot2_f32_f16 %0, %1, %2, %0\n \ - " - : "=v"(c) - : "v"(a), "v"(b), "0"(c)); -#else - c = __builtin_amdgcn_sdot2(a, b, c, false); -#endif -} - -template <> -__device__ void -amd_inner_product_dlop(const half4_t& a, const half4_t& b, float& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - amd_inner_product_dlop(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); -} - -template <> -__device__ void -amd_inner_product_dlop(const half8_t& a, const half8_t& b, float& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - amd_inner_product_dlop(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I2], - vector_type{b}.AsType()[I2], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I3], - vector_type{b}.AsType()[I3], - c); -} - -template <> -__device__ void amd_inner_product_dlop(const int8x4_t& a, - const int8x4_t& b, - int32_t& c) -{ -#if CK_USE_AMD_DLOP_INLINE_ASM - asm volatile("\n \ - v_dot4_i32_i8 %0, %1, %2, %0\n \ - " - : "=v"(c) - : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); -#else - c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); -#endif -} - -template <> -__device__ void amd_inner_product_dlop(const int8x8_t& a, - const int8x8_t& b, - int32_t& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - amd_inner_product_dlop(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - 
amd_inner_product_dlop(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); -} - -template <> -__device__ void amd_inner_product_dlop(const int8x16_t& a, - const int8x16_t& b, - int32_t& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - amd_inner_product_dlop(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I2], - vector_type{b}.AsType()[I2], - c); - - amd_inner_product_dlop(vector_type{a}.AsType()[I3], - vector_type{b}.AsType()[I3], - c); -} -#endif // CK_USE_AMD_DLOP - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index ce80fc0549..a2d9d5f062 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -2,6 +2,9 @@ #define CK_AMD_INLINE_ASM_HPP #include "data_type.hpp" +#include "c_style_pointer_cast.hpp" + +// TODO: deprecate all amd_assembly_outer_product_xxx namespace ck { @@ -53,9 +56,9 @@ __device__ void amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) { // TODO remove pointer casting - const half2_t* p_a_half2 = reinterpret_cast(&a); - const half2_t* p_b0_half2 = reinterpret_cast(&b0); - const half2_t* p_b1_half2 = reinterpret_cast(&b1); + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); // do dot2 two times asm volatile("\n \ @@ -114,11 +117,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, float& c3) { // TODO remove pointer casting - const half2_t* p_a_half2 = reinterpret_cast(&a); - const half2_t* p_b0_half2 = reinterpret_cast(&b0); - const half2_t* 
p_b1_half2 = reinterpret_cast(&b1); - const half2_t* p_b2_half2 = reinterpret_cast(&b2); - const half2_t* p_b3_half2 = reinterpret_cast(&b3); + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + const half2_t* p_b2_half2 = c_style_pointer_cast(&b2); + const half2_t* p_b3_half2 = c_style_pointer_cast(&b3); // do dot2 two times asm volatile("\n \ @@ -160,11 +163,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a, { // TODO remove pointer casting - const half4_t* p_a_half4 = reinterpret_cast(&a); - const half4_t* p_b0_half4 = reinterpret_cast(&b0); - const half4_t* p_b1_half4 = reinterpret_cast(&b1); - const half4_t* p_b2_half4 = reinterpret_cast(&b2); - const half4_t* p_b3_half4 = reinterpret_cast(&b3); + const half4_t* p_a_half4 = c_style_pointer_cast(&a); + const half4_t* p_b0_half4 = c_style_pointer_cast(&b0); + const half4_t* p_b1_half4 = c_style_pointer_cast(&b1); + const half4_t* p_b2_half4 = c_style_pointer_cast(&b2); + const half4_t* p_b3_half4 = c_style_pointer_cast(&b3); amd_assembly_outer_product_1x4( p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); @@ -184,11 +187,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a, float& c3) { // TODO remove pointer casting - const half8_t* p_a_half8 = reinterpret_cast(&a); - const half8_t* p_b0_half8 = reinterpret_cast(&b0); - const half8_t* p_b1_half8 = reinterpret_cast(&b1); - const half8_t* p_b2_half8 = reinterpret_cast(&b2); - const half8_t* p_b3_half8 = reinterpret_cast(&b3); + const half8_t* p_a_half8 = c_style_pointer_cast(&a); + const half8_t* p_b0_half8 = c_style_pointer_cast(&b0); + const half8_t* p_b1_half8 = c_style_pointer_cast(&b1); + const half8_t* p_b2_half8 = c_style_pointer_cast(&b2); + const half8_t* p_b3_half8 = c_style_pointer_cast(&b3); amd_assembly_outer_product_1x4( p_a_half8[0], p_b0_half8[0], p_b1_half8[0], 
p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); diff --git a/composable_kernel/include/utility/c_style_pointer_cast.hpp b/composable_kernel/include/utility/c_style_pointer_cast.hpp new file mode 100644 index 0000000000..8acf5790c6 --- /dev/null +++ b/composable_kernel/include/utility/c_style_pointer_cast.hpp @@ -0,0 +1,22 @@ +#ifndef CK_C_STYLE_POINTER_CAST_HPP +#define CK_C_STYLE_POINTER_CAST_HPP + +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { + +template && is_pointer_v, bool>::type = false> +__host__ __device__ PY c_style_pointer_cast(PX p_x) +{ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wcast-align" + return (PY)p_x; // NOLINT(old-style-cast, cast-align) +#pragma clang diagnostic pop +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 5ff7688a1c..85c02a1b99 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -7,13 +7,14 @@ #include "statically_indexed_array.hpp" #include "container_element_picker.hpp" #include "multi_index.hpp" -#include "data_type_enum.hpp" #include "data_type.hpp" -#include "data_type_helper.hpp" +#include "data_type_enum.hpp" +#include "data_type_enum_helper.hpp" #include "functional.hpp" #include "functional2.hpp" #include "functional3.hpp" #include "functional4.hpp" +#include "enable_if.hpp" #include "integral_constant.hpp" #include "math.hpp" #include "number.hpp" @@ -23,21 +24,21 @@ #include "tuple.hpp" #include "tuple_helper.hpp" #include "type.hpp" -#include "utility.hpp" #include "magic_division.hpp" -#include "amd_buffer_addressing_v2.hpp" +#include "utility.hpp" +#include "c_style_pointer_cast.hpp" +#include "amd_address_space.hpp" +#include "amd_buffer_addressing.hpp" #include "static_buffer.hpp" #include "dynamic_buffer.hpp" +#include "inner_product.hpp" + // 
TODO: remove this #if CK_USE_AMD_INLINE_ASM #include "amd_inline_asm.hpp" #endif -#if CK_USE_AMD_DLOP -#include "amd_dlop.hpp" -#endif - #if CK_USE_AMD_XDLOPS #include "amd_xdlops.hpp" #endif diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 4908d8d818..521ad24d47 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -7,19 +7,14 @@ #endif #include "bfloat16_dev.hpp" -// address space for kernel parameter +// "Constant" address space for kernel parameter #define CONSTANT __attribute__((address_space(4))) // GPU target // should enable one and only one GPU target #if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030)) -#error Need to define a single GPU target -#endif - -// HIP version -#ifndef CK_HIP_VERSION_FLAT -#define CK_HIP_VERSION_FLAT 0 +#error Need to define (only) one GPU target #endif // launch bounds @@ -38,6 +33,16 @@ #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 #endif +// FMA instruction +#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) +#define CK_USE_AMD_V_MAC_F32 +#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \ + defined(CK_AMD_GPU_GFX1030) +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8 +#endif + // multi index #define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 @@ -46,13 +51,9 @@ #define CK_USE_AMD_INLINE_ASM 1 #endif -// AMD DLOPS -#ifndef CK_USE_AMD_DLOP -#define CK_USE_AMD_DLOP 1 -#endif - -#ifndef CK_USE_AMD_DLOP_INLINE_ASM -#define CK_USE_AMD_DLOP_INLINE_ASM 1 +// AMD inner product (DLOP) +#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 #endif // AMD buffer addressing @@ -99,8 +100,8 @@ // hack for forcing register to keep idx_diff_low_const in 
SGPR. idx_diff_low_const must be // thread-invariant, otherwise it's a bug // TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" -#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE -#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 #endif // workaround for compiler crash when compiling recursive lambda @@ -120,15 +121,6 @@ namespace ck { -enum AddressSpaceEnum_t -{ - Generic, - Global, - Lds, - Sgpr, - Vgpr -}; - enum InMemoryDataOperationEnum_t { Set, diff --git a/composable_kernel/include/utility/data_type_enum.hpp b/composable_kernel/include/utility/data_type_enum.hpp index 43499605dc..35df0067a9 100644 --- a/composable_kernel/include/utility/data_type_enum.hpp +++ b/composable_kernel/include/utility/data_type_enum.hpp @@ -3,8 +3,7 @@ namespace ck { -// this enumerate should be synchronized with include/miopen.h -typedef enum +enum DataTypeEnum_t { Half = 0, Float = 1, @@ -14,7 +13,7 @@ typedef enum BFloat16 = 5, Double = 6, Unknown = 100, -} DataTypeEnum_t; +}; } // namespace ck #endif diff --git a/composable_kernel/include/utility/data_type_helper.hpp b/composable_kernel/include/utility/data_type_enum_helper.hpp similarity index 94% rename from composable_kernel/include/utility/data_type_helper.hpp rename to composable_kernel/include/utility/data_type_enum_helper.hpp index 6a234cd10b..451ce992b1 100644 --- a/composable_kernel/include/utility/data_type_helper.hpp +++ b/composable_kernel/include/utility/data_type_enum_helper.hpp @@ -1,5 +1,5 @@ -#ifndef CK_DATA_TYPE_HELPER_HPP -#define CK_DATA_TYPE_HELPER_HPP +#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP +#define CK_DATA_TYPE_ENUM_HELPER_HPP #include "data_type.hpp" #include "data_type_enum.hpp" diff --git 
a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 5f5f386306..4d583e3ce7 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -1,38 +1,49 @@ -#ifndef CK_DYNAMIC_BUFFER_HPP -#define CK_DYNAMIC_BUFFER_HPP +#ifndef CK_BUFFER_HPP +#define CK_BUFFER_HPP + +#include "amd_buffer_addressing.hpp" +#include "c_style_pointer_cast.hpp" +#include "enable_if.hpp" namespace ck { -#include "amd_buffer_addressing_v2.hpp" - -template +template struct DynamicBuffer { using type = T; T* p_data_; ElementSpaceSize element_space_size_; + T invalid_element_value_ = T{0}; __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) : p_data_{p_data}, element_space_size_{element_space_size} { } + __host__ __device__ constexpr DynamicBuffer(T* p_data, + ElementSpaceSize element_space_size, + T invalid_element_value) + : p_data_{p_data}, + element_space_size_{element_space_size}, + invalid_element_value_{invalid_element_value} + { + } + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() { return BufferAddressSpace; } - __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } - - __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } - template >>::type, typename scalar_type>>::type>::value, bool>::type = false> - __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = @@ -44,29 +55,50 @@ struct DynamicBuffer static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - - if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) - { #if CK_USE_AMD_BUFFER_ADDRESSING - return amd_buffer_load_v2>, t_per_x>( - p_data_, i, is_valid_offset, element_space_size_); + bool constexpr use_amd_buffer_addressing = true; #else - return is_valid_offset ? *reinterpret_cast(&p_data_[i]) : X{0}; + bool constexpr use_amd_buffer_addressing = false; #endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_return_zero< + remove_cv_t>, + t_per_x>(p_data_, i, is_valid_element, element_space_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value< + remove_cv_t>, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + } } else { - return is_valid_offset ? *reinterpret_cast(&p_data_[i]) : X{0}; + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) : X{0}; + } + else + { + return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) + : X{invalid_element_value_}; + } } } template >>::type, typename scalar_type>>::type>::value, bool>::type = false> - __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x) + __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = @@ -78,26 +110,26 @@ struct DynamicBuffer static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); - constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) { #if CK_USE_AMD_BUFFER_ADDRESSING - amd_buffer_store_v2>, t_per_x>( - x, p_data_, i, is_valid_offset, element_space_size_); + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store>, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); #else - if(is_valid_offset) + if(is_valid_element) { - *reinterpret_cast(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i]) = x; } #endif } else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) { - if(is_valid_offset) + if(is_valid_element) { #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE - *reinterpret_cast(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i]) = x; #else // HACK: compiler would lower IR "store address_space(3)" into // inefficient @@ -128,24 +160,24 @@ struct DynamicBuffer { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix - *reinterpret_cast(&p_data_[i]) = - *reinterpret_cast(&x); + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8_t>::value && is_same>, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix - *reinterpret_cast(&p_data_[i]) = - *reinterpret_cast(&x); + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8_t>::value && is_same>, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix - *reinterpret_cast(&p_data_[i]) = - *reinterpret_cast(&x); + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8x4_t>::value && @@ -153,8 +185,8 @@ struct DynamicBuffer { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix - *reinterpret_cast(&p_data_[i]) = - *reinterpret_cast(&x); + *c_style_pointer_cast(&p_data_[i]) = + 
*c_style_pointer_cast(&x); } else if constexpr(is_same>, int8x8_t>::value && @@ -162,8 +194,8 @@ struct DynamicBuffer { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix - *reinterpret_cast(&p_data_[i]) = - *reinterpret_cast(&x); + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); } else if constexpr(is_same>, int8x16_t>::value && @@ -171,22 +203,22 @@ struct DynamicBuffer { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix - *reinterpret_cast(&p_data_[i]) = - *reinterpret_cast(&x); + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); } } else { - *reinterpret_cast(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i]) = x; } #endif } } else { - if(is_valid_offset) + if(is_valid_element) { - *reinterpret_cast(&p_data_[i]) = x; + *c_style_pointer_cast(&p_data_[i]) = x; } } } @@ -196,12 +228,18 @@ struct DynamicBuffer __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; -template +template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) { - return DynamicBuffer{p, element_space_size}; + return DynamicBuffer{p, element_space_size}; +} + +template +__host__ __device__ constexpr auto +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value) +{ + return DynamicBuffer{ + p, element_space_size, invalid_element_value}; } } // namespace ck diff --git a/composable_kernel/include/utility/enable_if.hpp b/composable_kernel/include/utility/enable_if.hpp new file mode 100644 index 0000000000..501e1bfc1c --- /dev/null +++ b/composable_kernel/include/utility/enable_if.hpp @@ -0,0 +1,13 @@ +#ifndef CK_ENABLE_IF_HPP +#define CK_ENABLE_IF_HPP + +namespace ck { + +template +using enable_if = std::enable_if; + +template +using enable_if_t = typename std::enable_if::type; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/inner_product.hpp 
b/composable_kernel/include/utility/inner_product.hpp new file mode 100644 index 0000000000..51753accf3 --- /dev/null +++ b/composable_kernel/include/utility/inner_product.hpp @@ -0,0 +1,207 @@ +#ifndef CK_INNER_PRODUCT_HPP +#define CK_INNER_PRODUCT_HPP + +#include "data_type.hpp" + +namespace ck { + +template +__device__ void inner_product(const TA& a, const TB& b, TC& c); + +template <> +__device__ void inner_product(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) + asm volatile("\n \ + v_mac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c += a * b; +#endif +} + +template <> +__device__ void +inner_product(const float2_t& a, const float2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const float4_t& a, const float4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void inner_product(const half2_t& a, const half2_t& b, float& c) +{ +#if defined(CK_USE_AMD_V_DOT2_F32_F16) +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + asm volatile("\n \ + v_dot2_f32_f16 %0, %1, %2, %0\n \ + " + : 
"=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c = __builtin_amdgcn_sdot2(a, b, c, false); +#endif +#else + const auto convert = type_convert{}; + + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 2, 1>{}([&](auto i) { + c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void inner_product(const half4_t& a, const half4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void inner_product(const half8_t& a, const half8_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void +inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) +{ +#if defined(CK_USE_DOT4_I32_I8) +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); +#endif +#else + const auto convert = type_convert{}; + + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 4, 1>{}([&](auto i) { + c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void +inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + 
constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index e451059647..bcb25a2941 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -5,6 +5,7 @@ #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" +#include "enable_if.hpp" namespace ck { namespace math { @@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number, Number) return Number{}; } -template = 2, bool>::type = false> +template = 2, bool>::type = false> __host__ __device__ constexpr auto gcd(X x, Ys... ys) { return gcd(x, gcd(ys...)); @@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y) return (x * y) / gcd(x, y); } -template = 2, bool>::type = false> +template = 2, bool>::type = false> __host__ __device__ constexpr auto lcm(X x, Ys... 
ys) { return lcm(x, lcm(ys...)); diff --git a/composable_kernel/include/utility/print.hpp b/composable_kernel/include/utility/print.hpp index 0dd646153a..d7d58bbb83 100644 --- a/composable_kernel/include/utility/print.hpp +++ b/composable_kernel/include/utility/print.hpp @@ -11,59 +11,11 @@ namespace ck { template __host__ __device__ void print_array(const char* s, T a) { - using data_type = decltype(a.At(Number<0>{})); constexpr index_t nsize = a.Size(); -#if 0 - if constexpr(is_same{}) - { - printf("%s size %u, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); }); - printf("}\n"); - } - else if constexpr(is_same{}) - { - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); - printf("}\n"); - } - else if constexpr(is_same{}) - { - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); }); - printf("}\n"); - } -#else printf("%s size %d, {", s, nsize); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); printf("}\n"); -#endif -} - -template -__host__ __device__ void print_array_v2(const char* s, T a) -{ - using data_type = decltype(a.At(Number<0>{})); - constexpr index_t nsize = a.Size(); - -#if 0 - if constexpr(is_same{}) - { - printf("%s size %u, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); }); - printf("}\n"); - } - else if constexpr(is_same{}) - { - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); - printf("}\n"); - } -#else - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); - printf("}\n"); -#endif } } // namespace ck diff --git a/composable_kernel/include/utility/sequence.hpp b/composable_kernel/include/utility/sequence.hpp index 
81eb488715..b35999d56f 100644 --- a/composable_kernel/include/utility/sequence.hpp +++ b/composable_kernel/include/utility/sequence.hpp @@ -685,8 +685,6 @@ __host__ __device__ constexpr auto operator+(Number, Sequence) template __host__ __device__ constexpr auto operator-(Number, Sequence) { - constexpr auto seq_x = Sequence{}; - return Sequence<(Y - Xs)...>{}; } diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp index a23cf4f80d..cd67b8a0be 100644 --- a/composable_kernel/include/utility/static_buffer.hpp +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -5,30 +5,66 @@ namespace ck { -template +template struct StaticBuffer : public StaticallyIndexedArray { using type = T; using base = StaticallyIndexedArray; + T invalid_element_value_ = T{0}; + __host__ __device__ constexpr StaticBuffer() : base{} {} + __host__ __device__ constexpr StaticBuffer(T invalid_element_value) + : base{}, invalid_element_value_{invalid_element_value} + { + } + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() { return BufferAddressSpace; } + template + __host__ __device__ constexpr auto Get(Number i, bool is_valid_element) const + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? At(i) : T{0}; + } + else + { + return is_valid_element ? 
At(i) : invalid_element_value_; + } + } + + template + __host__ __device__ void Set(Number i, bool is_valid_element, const T& x) + { + if(is_valid_element) + { + At(i) = x; + } + } + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } }; -template +template __host__ __device__ constexpr auto make_static_buffer(Number) { - return StaticBuffer{}; + return StaticBuffer{}; +} + +template +__host__ __device__ constexpr auto make_static_buffer(Number, T invalid_element_value) +{ + return StaticBuffer{invalid_element_value}; } } // namespace ck diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index 15b73011b4..ee96a8b435 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -4,6 +4,7 @@ #include "integral_constant.hpp" #include "sequence.hpp" #include "type.hpp" +#include "enable_if.hpp" namespace ck { @@ -20,10 +21,9 @@ struct TupleElement { __host__ __device__ constexpr TupleElement() = default; - template < - typename T, - typename std::enable_if>, TupleElement>::value, - bool>::type = false> + template >, TupleElement>::value, + bool>::type = false> __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) { } @@ -58,17 +58,16 @@ struct TupleImpl, Xs...> : TupleElement, Xs> { __host__ __device__ constexpr TupleImpl() = default; - template < - typename Y, - typename std::enable_if>, TupleImpl>::value, - bool>::type = false> + template >, TupleImpl>::value, + bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) : TupleElement, Xs>(std::forward(y))... { } - template = 2, bool>::type = false> + template = 2, bool>::type = false> __host__ __device__ constexpr TupleImpl(Ys&&... ys) : TupleElement, Xs>(std::forward(ys))... 
{ @@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl>, Tuple>::value, - bool>::type = false> + typename enable_if>, Tuple>::value, + bool>::type = false> __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { } template = 2, - bool>::type = false> + typename enable_if= 2, bool>::type = + false> __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward(ys)...) { } diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp index 32f7dfb569..b7902ad496 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/composable_kernel/include/utility/type.hpp @@ -2,6 +2,7 @@ #define CK_TYPE_HPP #include "integral_constant.hpp" +#include "enable_if.hpp" namespace ck { @@ -22,10 +23,7 @@ template using remove_cv_t = typename std::remove_cv::type; template -constexpr std::remove_reference_t&& move(T&& t) noexcept -{ - return static_cast::type&&>(t); -} +inline constexpr bool is_pointer_v = std::is_pointer::value; template struct is_known_at_compile_time; @@ -42,9 +40,7 @@ struct is_known_at_compile_time> static constexpr bool value = true; }; -template ::type = false> +template ::type = false> __host__ __device__ constexpr Y as_type(X x) { union AsType diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..09a7fffa3e --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,370 @@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v1r2.hpp" +#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t 
AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; +constexpr index_t M1PerThread = CK_PARAM_M1PerThread; +constexpr index_t N1PerThread = CK_PARAM_N1PerThread; +constexpr index_t KPerThread = CK_PARAM_KPerThread; +constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; +constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; +constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; +constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; + +using ABlockTransferThreadSliceLengths_K_M0_M1 = + Sequence; +using ABlockTransferThreadClusterLengths_K_M0_M1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_M1 = + CK_PARAM_ABlockTransferDstScalarPerVector_M1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K_N0_N1 = + Sequence; +using BBlockTransferThreadClusterLengths_K_N0_N1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = 
CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_N1 = + CK_PARAM_BBlockTransferDstScalarPerVector_N1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); + +extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k_m0_m1_grid_desc, + void* p_b_k_n0_n1_grid_desc, + void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + 
make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW)); + + const auto a_k_m_grid_desc = descs[I0]; + const auto b_k_n_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AKMGridDesc = decltype(a_k_m_grid_desc); + using BKNGridDesc = decltype(b_k_n_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r2; + + auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + auto 
c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc; + *static_cast(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; + *static_cast( + p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + }; +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* p_b_k_n0_n1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1)); + + constexpr auto a_k_m_grid_desc = descs[I0]; + constexpr auto b_k_n_grid_desc = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AKMGridDesc = decltype(a_k_m_grid_desc); + using BKNGridDesc = decltype(b_k_n_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridStepHacks = 
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r2; + + constexpr auto a_k_m0_m1_grid_desc_tmp = + GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + constexpr auto b_k_n0_n1_grid_desc_tmp = + GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); + using 
CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k_m0_m1_grid_desc = + *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); + const auto b_k_n0_n1_grid_desc = + *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..51d852617f --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,358 @@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename 
get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = 
CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + 
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* 
p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 
2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp new file mode 
100644 index 0000000000..30e4c518ce --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -0,0 +1,357 @@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; 
+using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( + int n, + int hi, + int wi, + int c, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c)); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c)); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, 
+ wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + 
auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); + constexpr auto wei_k_y_x_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256)); + constexpr auto out_n_ho_wo_k_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridStepHacks = 
decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + 
*reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp rename to composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index 90c957bb0b..c1208ac3cb 100644 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -1,7 +1,7 @@ #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_contraction_dlops_v1r2.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" using namespace ck; @@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBloc constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); extern "C" __global__ void -dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, - 
index_t C, - index_t Hi, - index_t Wi, - index_t K, - index_t Y, - index_t X, - index_t ConvStrideH, - index_t ConvStrideW, - index_t ConvDilationH, - index_t ConvDilationW, - index_t InLeftPadH, - index_t InLeftPadW, - index_t InRightPadH, - index_t InRightPadW, - void* p_desc_tuple) +convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_, + int C_, + int Hi_, + int Wi_, + int K_, + int Y_, + int X_, + int ConvStrideH_, + int ConvStrideW_, + int ConvDilationH_, + int ConvDilationW_, + int InLeftPadH_, + int InLeftPadW_, + int InRightPadH_, + int InRightPadW_, + void* p_desc_tuple) { + index_t N = static_cast(N_); + index_t C = static_cast(C_); + index_t Hi = static_cast(Hi_); + index_t Wi = static_cast(Wi_); + index_t K = static_cast(K_); + index_t Y = static_cast(Y_); + index_t X = static_cast(X_); + index_t ConvStrideH = static_cast(ConvStrideH_); + index_t ConvStrideW = static_cast(ConvStrideW_); + index_t ConvDilationH = static_cast(ConvDilationH_); + index_t ConvDilationW = static_cast(ConvDilationW_); + index_t InLeftPadH = static_cast(InLeftPadH_); + index_t InLeftPadW = static_cast(InLeftPadW_); + index_t InRightPadH = static_cast(InRightPadH_); + index_t InRightPadW = static_cast(InRightPadW_); + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -88,12 +104,9 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde const index_t Wo = (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi)); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X)); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi)); + const auto wei_k_c_y_x_desc = 
make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X)); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( wei_k_c_y_x_desc, @@ -114,7 +127,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - using AGridIteratorHacks = + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 @@ -126,7 +139,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using BGridIteratorHacks = decltype(make_tuple( + using BGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 @@ -138,7 +151,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using CGridIteratorHacks = decltype(make_tuple( + using CGridStepHacks = decltype(make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 @@ -154,13 +167,13 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; using GridwiseContraction = - GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< BlockSize, FloatAB, FloatAcc, @@ -194,11 +207,11 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) { @@ -220,7 +233,7 @@ extern "C" __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -232,11 +245,11 @@ extern "C" __global__ void constexpr auto I3 = Number<3>{}; constexpr auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); constexpr auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); constexpr auto out_n_k_ho_wo_desc = - 
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); constexpr auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, @@ -257,7 +270,7 @@ extern "C" __global__ void using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); - using AGridIteratorHacks = + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 @@ -269,7 +282,7 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using BGridIteratorHacks = decltype(make_tuple( + using BGridStepHacks = decltype(make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 @@ -281,7 +294,7 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 - using CGridIteratorHacks = decltype(make_tuple( + using CGridStepHacks = decltype(make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 @@ -297,13 +310,13 @@ extern "C" __global__ void Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; - using 
BGridMoveSliceWindowIteratorHacks = + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; using GridwiseContraction = - GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< BlockSize, FloatAB, FloatAcc, @@ -337,11 +350,11 @@ extern "C" __global__ void CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 652ccdb926..0000000000 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,374 +0,0 @@ -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename 
get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; -constexpr index_t M1PerThread = CK_PARAM_M1PerThread; -constexpr index_t N1PerThread = CK_PARAM_N1PerThread; -constexpr index_t KPerThread = CK_PARAM_KPerThread; -constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; -constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; -constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; -constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; - -using ABlockTransferThreadSliceLengths_K_M0_M1 = - Sequence; -using ABlockTransferThreadClusterLengths_K_M0_M1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_M1 = - CK_PARAM_ABlockTransferDstScalarPerVector_M1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K_N0_N1 = - Sequence; -using BBlockTransferThreadClusterLengths_K_N0_N1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_N1 = - CK_PARAM_BBlockTransferDstScalarPerVector_N1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); 
- -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); - -extern "C" __global__ void -dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k_m0_m1_grid_desc, - void* p_b_k_n0_n1_grid_desc, - void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW)); - - const auto a_k_m_grid_desc = descs[I0]; - const auto b_k_n_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AKMGridDesc = 
decltype(a_k_m_grid_desc); - using BKNGridDesc = decltype(b_k_n_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); - - using BGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseDynamicGemmDlops_km_kn_mn_v1r2; - - auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc; - 
*static_cast(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; - *static_cast( - p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - }; -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k_m0_m1_grid_desc, - const void CONSTANT* p_b_k_n0_n1_grid_desc, - const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1)); - - constexpr auto a_k_m_grid_desc = descs[I0]; - constexpr auto b_k_n_grid_desc = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AKMGridDesc = decltype(a_k_m_grid_desc); - using BKNGridDesc = decltype(b_k_n_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, 
- Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}))); - - using BGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseDynamicGemmDlops_km_kn_mn_v1r2; - - constexpr auto a_k_m0_m1_grid_desc_tmp = - GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - constexpr auto b_k_n0_n1_grid_desc_tmp = - GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k_m0_m1_grid_desc = - *reinterpret_cast((const 
void*)p_a_k_m0_m1_grid_desc); - const auto b_k_n0_n1_grid_desc = - *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - *reinterpret_cast( - (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp deleted file mode 100644 index d33bc74aa6..0000000000 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,362 +0,0 @@ -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t 
MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; - -constexpr index_t MPerWave = CK_PARAM_MPerWave; -constexpr index_t NPerWave = CK_PARAM_NPerWave; -constexpr index_t MRepeat = CK_PARAM_MRepeat; -constexpr index_t NRepeat = CK_PARAM_NRepeat; -constexpr index_t K1 = CK_PARAM_K1; - -using ABlockTransferThreadSliceLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_K1 = - CK_PARAM_ABlockTransferDstScalarPerVector_K1; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_K1 = - CK_PARAM_BBlockTransferDstScalarPerVector_K1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -extern "C" __global__ void 
-dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k0_m_k1_grid_desc, - void* p_b_k0_n_k1_grid_desc, - void* p_c_m0_m1_m2_n_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW), - Number{}); - - const auto a_k0_m_k1_grid_desc = descs[I0]; - const auto b_k0_n_k1_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using AGridIteratorHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using BGridIteratorHacks = - 
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using GridwiseGemm = - GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast*>(p_a_k0_m_k1_grid_desc) = - a_k0_m_k1_grid_desc; - *static_cast*>(p_b_k0_n_k1_grid_desc) = - b_k0_n_k1_grid_desc; - *static_cast(p_c_m0_m1_m2_n_grid_desc) = - c_m0_m1_m2_n_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ 
p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); - - constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; - constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AGridIteratorHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using BGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - 
Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using GridwiseGemm = - GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - constexpr auto c_m0_m1_m2_n_grid_desc_tmp = - GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k0_m_k1_grid_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); -}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp 
b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp deleted file mode 100644 index d49693b511..0000000000 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp +++ /dev/null @@ -1,362 +0,0 @@ -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" - -using namespace ck; - -constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); -constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); -constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); - -using FloatAB = typename get_datatype_from_enum::type; -using FloatAcc = typename get_datatype_from_enum::type; -using FloatC = typename get_datatype_from_enum::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; - -constexpr index_t MPerBlock = CK_PARAM_MPerBlock; -constexpr index_t NPerBlock = CK_PARAM_NPerBlock; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; - -constexpr index_t MPerWave = CK_PARAM_MPerWave; -constexpr index_t NPerWave = CK_PARAM_NPerWave; -constexpr index_t MRepeat = CK_PARAM_MRepeat; -constexpr index_t NRepeat = CK_PARAM_NRepeat; -constexpr index_t K1 = CK_PARAM_K1; - -using ABlockTransferThreadSliceLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterLengths_K0_M_K1 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_K1 = - CK_PARAM_ABlockTransferDstScalarPerVector_K1; -constexpr bool 
AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterLengths_K0_N_K1 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_K1 = - CK_PARAM_BBlockTransferDstScalarPerVector_K1; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -extern "C" __global__ void -dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( - int n, - int hi, - int wi, - int c, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_k0_m_k1_grid_desc, - void* p_b_k0_n_k1_grid_desc, - void* p_c_m0_m1_m2_n_grid_desc, - void* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c)); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, 
c)); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k)); - - const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW), - Number{}); - - const auto a_k0_m_k1_grid_desc = descs[I0]; - const auto b_k0_n_k1_grid_desc = descs[I1]; - const auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using BGridIteratorHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using AGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 
2, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - - using GridwiseGemm = - GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast*>(p_a_k0_m_k1_grid_desc) = - a_k0_m_k1_grid_desc; - *static_cast*>(p_b_k0_n_k1_grid_desc) = - b_k0_n_k1_grid_desc; - *static_cast(p_c_m0_m1_m2_n_grid_desc) = - c_m0_m1_m2_n_grid_desc; - *static_cast( - p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; - } -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void CONSTANT* p_a_k0_m_k1_grid_desc, - const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, - const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) -{ - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256)); - constexpr auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256)); - constexpr auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256)); - - constexpr auto descs = - transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, - wei_k_y_x_c_desc, - out_n_ho_wo_k_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - Number{}); - - constexpr auto 
a_k0_m_k1_grid_desc_tmp = descs[I0]; - constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; - constexpr auto c_m_n_grid_desc = descs[I2]; - - using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); - using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); - using CMNGridDesc = decltype(c_m_n_grid_desc); - - using BGridIteratorHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); - - using AGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>; - - using GridwiseGemm = - GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - constexpr auto c_m0_m1_m2_n_grid_desc_tmp = - GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = - GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = 
decltype(c_m0_m1_m2_n_grid_desc_tmp); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); - - const auto a_k0_m_k1_grid_desc = - *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); - const auto b_k0_n_k1_grid_desc = - *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); - const auto c_m0_m1_m2_n_grid_desc = - *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); -}; diff --git a/external/half/include/half.hpp b/external/half/include/half.hpp deleted file mode 100644 index 25f543881f..0000000000 --- a/external/half/include/half.hpp +++ /dev/null @@ -1,5670 +0,0 @@ -// half - IEEE 754-based half-precision floating-point library. -// -// Copyright (c) 2012-2019 Christian Rau -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -// associated documentation -// files (the "Software"), to deal in the Software without restriction, including without limitation -// the rights to use, copy, -// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit -// persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. 
-// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -// NOT LIMITED TO THE -// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR -// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF -// CONTRACT, TORT OR OTHERWISE, -// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// Version 2.1.0 - -/// \file -/// Main header file for half-precision functionality. - -#ifndef HALF_HALF_HPP -#define HALF_HALF_HPP - -#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) - -#if defined(__INTEL_COMPILER) -#define HALF_ICC_VERSION __INTEL_COMPILER -#elif defined(__ICC) -#define HALF_ICC_VERSION __ICC -#elif defined(__ICL) -#define HALF_ICC_VERSION __ICL -#else -#define HALF_ICC_VERSION 0 -#endif - -// check C++11 language features -#if defined(__clang__) // clang -#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if(defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ - !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++ -#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if 
HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#elif defined(__GNUC__) // gcc -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L -#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) -#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#endif -#define HALF_TWOS_COMPLEMENT_INT 1 -#elif defined(_MSC_VER) // Visual C++ -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL) -#define HALF_ENABLE_CPP11_THREAD_LOCAL 1 -#endif -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) -#define HALF_ENABLE_CPP11_USER_LITERALS 1 -#endif -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) -#define HALF_ENABLE_CPP11_CONSTEXPR 1 -#endif -#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) -#define HALF_ENABLE_CPP11_NOEXCEPT 1 -#endif -#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) 
-#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 -#endif -#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) -#define HALF_ENABLE_CPP11_LONG_LONG 1 -#endif -#define HALF_TWOS_COMPLEMENT_INT 1 -#define HALF_POP_WARNINGS 1 -#pragma warning(push) -#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned -#endif - -// check C++11 library features -#include -#if defined(_LIBCPP_VERSION) // libc++ -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 -#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#ifndef HALF_ENABLE_CPP11_CSTDINT -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#ifndef HALF_ENABLE_CPP11_CMATH -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#ifndef HALF_ENABLE_CPP11_HASH -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#ifndef HALF_ENABLE_CPP11_CFENV -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#endif -#elif defined(__GLIBCXX__) // libstdc++ -#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 -#ifdef __clang__ -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV) -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#else -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#if HALF_GCC_VERSION >= 403 && 
!defined(HALF_ENABLE_CPP11_HASH) -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV) -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#endif -#endif -#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ -#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) -#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 -#endif -#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT) -#define HALF_ENABLE_CPP11_CSTDINT 1 -#endif -#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH) -#define HALF_ENABLE_CPP11_HASH 1 -#endif -#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH) -#define HALF_ENABLE_CPP11_CMATH 1 -#endif -#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV) -#define HALF_ENABLE_CPP11_CFENV 1 -#endif -#endif -#undef HALF_GCC_VERSION -#undef HALF_ICC_VERSION - -// any error throwing C++ exceptions? -#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \ - defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \ - defined(HALF_ERRHANDLING_THROW_INEXACT) -#define HALF_ERRHANDLING_THROWS 1 -#endif - -// any error handling enabled? 
-#define HALF_ERRHANDLING \ - (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || HALF_ERRHANDLING_FENV || \ - HALF_ERRHANDLING_THROWS) - -#if HALF_ERRHANDLING -#define HALF_UNUSED_NOERR(name) name -#else -#define HALF_UNUSED_NOERR(name) -#endif - -// support constexpr -#if HALF_ENABLE_CPP11_CONSTEXPR -#define HALF_CONSTEXPR constexpr -#define HALF_CONSTEXPR_CONST constexpr -#if HALF_ERRHANDLING -#define HALF_CONSTEXPR_NOERR -#else -#define HALF_CONSTEXPR_NOERR constexpr -#endif -#else -#define HALF_CONSTEXPR -#define HALF_CONSTEXPR_CONST const -#define HALF_CONSTEXPR_NOERR -#endif - -// support noexcept -#if HALF_ENABLE_CPP11_NOEXCEPT -#define HALF_NOEXCEPT noexcept -#define HALF_NOTHROW noexcept -#else -#define HALF_NOEXCEPT -#define HALF_NOTHROW throw() -#endif - -// support thread storage -#if HALF_ENABLE_CPP11_THREAD_LOCAL -#define HALF_THREAD_LOCAL thread_local -#else -#define HALF_THREAD_LOCAL static -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if HALF_ENABLE_CPP11_TYPE_TRAITS -#include -#endif -#if HALF_ENABLE_CPP11_CSTDINT -#include -#endif -#if HALF_ERRHANDLING_ERRNO -#include -#endif -#if HALF_ENABLE_CPP11_CFENV -#include -#endif -#if HALF_ENABLE_CPP11_HASH -#include -#endif -#if HALF_ENABLE_F16C_INTRINSICS -#include -#endif - -#ifndef HALF_ENABLE_F16C_INTRINSICS -/// Enable F16C intruction set intrinsics. -/// Defining this to 1 enables the use of [F16C compiler -/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between -/// half-precision and single-precision values which may result in improved performance. This will -/// not perform additional checks -/// for support of the F16C instruction set, so an appropriate target platform is required when -/// enabling this feature. -/// -/// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which -/// some compilers do on supporting platforms. 
-#define HALF_ENABLE_F16C_INTRINSICS __F16C__ -#endif - -#ifdef HALF_DOXYGEN_ONLY -/// Type for internal floating-point computations. -/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to -/// override the internal -/// half-precision implementation to use this type for computing arithmetic operations and -/// mathematical function (if available). -/// This can result in improved performance for arithmetic operators and mathematical functions but -/// might cause results to -/// deviate from the specified half-precision rounding mode and inhibits proper detection of -/// half-precision exceptions. -#define HALF_ARITHMETIC_TYPE (undefined) - -/// Enable internal exception flags. -/// Defining this to 1 causes operations on half-precision values to raise internal floating-point -/// exception flags according to -/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept(). -#define HALF_ERRHANDLING_FLAGS 0 - -/// Enable exception propagation to `errno`. -/// Defining this to 1 causes operations on half-precision values to propagate floating-point -/// exceptions to -/// [errno](https://en.cppreference.com/w/cpp/error/errno) from ``. Specifically this will -/// propagate domain errors as -/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow -/// errors as -/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be -/// propagated. -#define HALF_ERRHANDLING_ERRNO 0 - -/// Enable exception propagation to built-in floating-point platform. -/// Defining this to 1 causes operations on half-precision values to propagate floating-point -/// exceptions to the built-in -/// single- and double-precision implementation's exception flags using the -/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from -/// ``. 
However, this -/// does not work in reverse and single- or double-precision exceptions will not raise the -/// corresponding half-precision -/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags. -#define HALF_ERRHANDLING_FENV 0 - -/// Throw C++ exception on domain errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified -/// message on domain errors. -#define HALF_ERRHANDLING_THROW_INVALID (undefined) - -/// Throw C++ exception on pole errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified -/// message on pole errors. -#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined) - -/// Throw C++ exception on overflow errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified -/// message on overflows. -#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined) - -/// Throw C++ exception on underflow errors. -/// Defining this to a string literal causes operations on half-precision values to throw a -/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the -/// specified message on underflows. -#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined) - -/// Throw C++ exception on rounding errors. -/// Defining this to 1 causes operations on half-precision values to throw a -/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified -/// message on general rounding errors. -#define HALF_ERRHANDLING_THROW_INEXACT (undefined) -#endif - -#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT -/// Raise INEXACT exception on overflow. 
-/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in -/// addition. -/// These will be raised after any possible handling of the underflow exception. -#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1 -#endif - -#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT -/// Raise INEXACT exception on underflow. -/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions -/// in addition. -/// These will be raised after any possible handling of the underflow exception. -/// -/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be -/// raised *only* when the result -/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact) -/// subnormal result. -#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1 -#endif - -/// Default rounding mode. -/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s -/// and more precise types -/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic -/// operations and mathematical -/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes -/// using their respective -/// constants or the equivalent values of -/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style): -/// -/// `std::float_round_style` | value | rounding -/// ---------------------------------|-------|------------------------- -/// `std::round_indeterminate` | -1 | fastest -/// `std::round_toward_zero` | 0 | toward zero -/// `std::round_to_nearest` | 1 | to nearest (default) -/// `std::round_toward_infinity` | 2 | toward positive infinity -/// `std::round_toward_neg_infinity` | 3 | toward negative infinity -/// -/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest -/// representable value. 
It can even -/// be set to -/// [std::numeric_limits::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style) -/// to synchronize -/// the rounding mode with that of the built-in single-precision implementation (which is likely -/// `std::round_to_nearest`, though). -#ifndef HALF_ROUND_STYLE -#define HALF_ROUND_STYLE 1 // = std::round_to_nearest -#endif - -/// Value signaling overflow. -/// In correspondence with `HUGE_VAL[F|L]` from `` this symbol expands to a positive value -/// signaling the overflow of an -/// operation, in particular it just evaluates to positive infinity. -/// -/// **See also:** Documentation for -/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL) -#define HUGE_VALH std::numeric_limits::infinity() - -/// Fast half-precision fma function. -/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a -/// separate -/// half-precision multiplication followed by an addition, which is always the case. -/// -/// **See also:** Documentation for -/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma) -#define FP_FAST_FMAH 1 - -/// Half rounding mode. -/// In correspondence with `FLT_ROUNDS` from `` this symbol expands to the rounding mode -/// used for -/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE). 
-/// -/// **See also:** Documentation for -/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS) -#define HLF_ROUNDS HALF_ROUND_STYLE - -#ifndef FP_ILOGB0 -#define FP_ILOGB0 INT_MIN -#endif -#ifndef FP_ILOGBNAN -#define FP_ILOGBNAN INT_MAX -#endif -#ifndef FP_SUBNORMAL -#define FP_SUBNORMAL 0 -#endif -#ifndef FP_ZERO -#define FP_ZERO 1 -#endif -#ifndef FP_NAN -#define FP_NAN 2 -#endif -#ifndef FP_INFINITE -#define FP_INFINITE 3 -#endif -#ifndef FP_NORMAL -#define FP_NORMAL 4 -#endif - -#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT) -#define FE_INVALID 0x10 -#define FE_DIVBYZERO 0x08 -#define FE_OVERFLOW 0x04 -#define FE_UNDERFLOW 0x02 -#define FE_INEXACT 0x01 -#define FE_ALL_EXCEPT (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT) -#endif - -/// Main namespace for half-precision functionality. -/// This namespace contains all the functionality provided by the library. -namespace half_float { -class half; - -#if HALF_ENABLE_CPP11_USER_LITERALS -/// Library-defined half-precision literals. -/// Import this namespace to enable half-precision floating-point literals: -/// ~~~~{.cpp} -/// using namespace half_float::literal; -/// half_float::half = 4.2_h; -/// ~~~~ -namespace literal { -half operator"" _h(long double); -} -#endif - -/// \internal -/// \brief Implementation details. -namespace detail { -#if HALF_ENABLE_CPP11_TYPE_TRAITS -/// Conditional type. -template -struct conditional : std::conditional -{ -}; - -/// Helper for tag dispatching. -template -struct bool_type : std::integral_constant -{ -}; -using std::false_type; -using std::true_type; - -/// Type traits for floating-point types. -template -struct is_float : std::is_floating_point -{ -}; -#else -/// Conditional type. -template -struct conditional -{ - typedef T type; -}; -template -struct conditional -{ - typedef F type; -}; - -/// Helper for tag dispatching. 
-template -struct bool_type -{ -}; -typedef bool_type true_type; -typedef bool_type false_type; - -/// Type traits for floating-point types. -template -struct is_float : false_type -{ -}; -template -struct is_float : is_float -{ -}; -template -struct is_float : is_float -{ -}; -template -struct is_float : is_float -{ -}; -template <> -struct is_float : true_type -{ -}; -template <> -struct is_float : true_type -{ -}; -template <> -struct is_float : true_type -{ -}; -#endif - -/// Type traits for floating-point bits. -template -struct bits -{ - typedef unsigned char type; -}; -template -struct bits : bits -{ -}; -template -struct bits : bits -{ -}; -template -struct bits : bits -{ -}; - -#if HALF_ENABLE_CPP11_CSTDINT -/// Unsigned integer of (at least) 16 bits width. -typedef std::uint_least16_t uint16; - -/// Fastest unsigned integer of (at least) 32 bits width. -typedef std::uint_fast32_t uint32; - -/// Fastest signed integer of (at least) 32 bits width. -typedef std::int_fast32_t int32; - -/// Unsigned integer of (at least) 32 bits width. -template <> -struct bits -{ - typedef std::uint_least32_t type; -}; - -/// Unsigned integer of (at least) 64 bits width. -template <> -struct bits -{ - typedef std::uint_least64_t type; -}; -#else -/// Unsigned integer of (at least) 16 bits width. -typedef unsigned short uint16; - -/// Fastest unsigned integer of (at least) 32 bits width. -typedef unsigned long uint32; - -/// Fastest unsigned integer of (at least) 32 bits width. -typedef long int32; - -/// Unsigned integer of (at least) 32 bits width. -template <> -struct bits - : conditional::digits >= 32, unsigned int, unsigned long> -{ -}; - -#if HALF_ENABLE_CPP11_LONG_LONG -/// Unsigned integer of (at least) 64 bits width. -template <> -struct bits : conditional::digits >= 64, - unsigned long, - unsigned long long> -{ -}; -#else -/// Unsigned integer of (at least) 64 bits width. 
-template <> -struct bits -{ - typedef unsigned long type; -}; -#endif -#endif - -#ifdef HALF_ARITHMETIC_TYPE -/// Type to use for arithmetic computations and mathematic functions internally. -typedef HALF_ARITHMETIC_TYPE internal_t; -#endif - -/// Tag type for binary construction. -struct binary_t -{ -}; - -/// Tag for binary construction. -HALF_CONSTEXPR_CONST binary_t binary = binary_t(); - -/// \name Implementation defined classification and arithmetic -/// \{ - -/// Check for infinity. -/// \tparam T argument type (builtin floating-point type) -/// \param arg value to query -/// \retval true if infinity -/// \retval false else -template -bool builtin_isinf(T arg) -{ -#if HALF_ENABLE_CPP11_CMATH - return std::isinf(arg); -#elif defined(_MSC_VER) - return !::_finite(static_cast(arg)) && !::_isnan(static_cast(arg)); -#else - return arg == std::numeric_limits::infinity() || arg == -std::numeric_limits::infinity(); -#endif -} - -/// Check for NaN. -/// \tparam T argument type (builtin floating-point type) -/// \param arg value to query -/// \retval true if not a number -/// \retval false else -template -bool builtin_isnan(T arg) -{ -#if HALF_ENABLE_CPP11_CMATH - return std::isnan(arg); -#elif defined(_MSC_VER) - return ::_isnan(static_cast(arg)) != 0; -#else - return arg != arg; -#endif -} - -/// Check sign. -/// \tparam T argument type (builtin floating-point type) -/// \param arg value to query -/// \retval true if signbit set -/// \retval false else -template -bool builtin_signbit(T arg) -{ -#if HALF_ENABLE_CPP11_CMATH - return std::signbit(arg); -#else - return arg < T() || (arg == T() && T(1) / arg < T()); -#endif -} - -/// Platform-independent sign mask. 
-/// \param arg integer value in two's complement -/// \retval -1 if \a arg negative -/// \retval 0 if \a arg positive -inline uint32 sign_mask(uint32 arg) -{ - static const int N = std::numeric_limits::digits - 1; -#if HALF_TWOS_COMPLEMENT_INT - return static_cast(arg) >> N; -#else - return -((arg >> N) & 1); -#endif -} - -/// Platform-independent arithmetic right shift. -/// \param arg integer value in two's complement -/// \param i shift amount (at most 31) -/// \return \a arg right shifted for \a i bits with possible sign extension -inline uint32 arithmetic_shift(uint32 arg, int i) -{ -#if HALF_TWOS_COMPLEMENT_INT - return static_cast(arg) >> i; -#else - return static_cast(arg) / (static_cast(1) << i) - - ((arg >> (std::numeric_limits::digits - 1)) & 1); -#endif -} - -/// \} -/// \name Error handling -/// \{ - -/// Internal exception flags. -/// \return reference to global exception flags -inline int& errflags() -{ - HALF_THREAD_LOCAL int flags = 0; - return flags; -} - -/// Raise floating-point exception. 
-/// \param flags exceptions to raise -/// \param cond condition to raise exceptions for -inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true) -{ -#if HALF_ERRHANDLING - if(!cond) - return; -#if HALF_ERRHANDLING_FLAGS - errflags() |= flags; -#endif -#if HALF_ERRHANDLING_ERRNO - if(flags & FE_INVALID) - errno = EDOM; - else if(flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW)) - errno = ERANGE; -#endif -#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV - std::feraiseexcept(flags); -#endif -#ifdef HALF_ERRHANDLING_THROW_INVALID - if(flags & FE_INVALID) - throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID); -#endif -#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO - if(flags & FE_DIVBYZERO) - throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO); -#endif -#ifdef HALF_ERRHANDLING_THROW_OVERFLOW - if(flags & FE_OVERFLOW) - throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW); -#endif -#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW - if(flags & FE_UNDERFLOW) - throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW); -#endif -#ifdef HALF_ERRHANDLING_THROW_INEXACT - if(flags & FE_INEXACT) - throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT); -#endif -#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT - if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT)) - raise(FE_INEXACT); -#endif -#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT - if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT)) - raise(FE_INEXACT); -#endif -#endif -} - -/// Check and signal for any NaN. 
-/// \param x first half-precision value to check -/// \param y second half-precision value to check -/// \retval true if either \a x or \a y is NaN -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00); -#endif - return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00; -} - -/// Signal and silence signaling NaN. -/// \param nan half-precision NaN value -/// \return quiet NaN -/// \exception FE_INVALID if \a nan is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, !(nan & 0x200)); -#endif - return nan | 0x200; -} - -/// Signal and silence signaling NaNs. -/// \param x first half-precision value to check -/// \param y second half-precision value to check -/// \return quiet NaN -/// \exception FE_INVALID if \a x or \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, - ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200))); -#endif - return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200); -} - -/// Signal and silence signaling NaNs. -/// \param x first half-precision value to check -/// \param y second half-precision value to check -/// \param z third half-precision value to check -/// \return quiet NaN -/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z) -{ -#if HALF_ERRHANDLING - raise(FE_INVALID, - ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || - ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); -#endif - return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) - : ((y & 0x7FFF) > 0x7C00) ? 
(y | 0x200) : (z | 0x200); -} - -/// Select value or signaling NaN. -/// \param x preferred half-precision value -/// \param y ignored half-precision value except for signaling NaN -/// \return \a y if signaling NaN, \a x otherwise -/// \exception FE_INVALID if \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y)) -{ -#if HALF_ERRHANDLING - return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? signal(y) : x; -#else - return x; -#endif -} - -/// Raise domain error and return NaN. -/// return quiet NaN -/// \exception FE_INVALID -inline HALF_CONSTEXPR_NOERR unsigned int invalid() -{ -#if HALF_ERRHANDLING - raise(FE_INVALID); -#endif - return 0x7FFF; -} - -/// Raise pole error and return infinity. -/// \param sign half-precision value with sign bit only -/// \return half-precision infinity with sign of \a sign -/// \exception FE_DIVBYZERO -inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0) -{ -#if HALF_ERRHANDLING - raise(FE_DIVBYZERO); -#endif - return sign | 0x7C00; -} - -/// Check value for underflow. -/// \param arg non-zero half-precision value to check -/// \return \a arg -/// \exception FE_UNDERFLOW if arg is subnormal -inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg) -{ -#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT - raise(FE_UNDERFLOW, !(arg & 0x7C00)); -#endif - return arg; -} - -/// \} -/// \name Conversion and rounding -/// \{ - -/// Half-precision overflow. -/// \tparam R rounding mode to use -/// \param sign half-precision value with sign bit only -/// \return rounded overflowing half-precision value -/// \exception FE_OVERFLOW -template -HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0) -{ -#if HALF_ERRHANDLING - raise(FE_OVERFLOW); -#endif - return (R == std::round_toward_infinity) - ? (sign + 0x7C00 - (sign >> 15)) - : (R == std::round_toward_neg_infinity) - ? 
(sign + 0x7BFF + (sign >> 15)) - : (R == std::round_toward_zero) ? (sign | 0x7BFF) : (sign | 0x7C00); -} - -/// Half-precision underflow. -/// \tparam R rounding mode to use -/// \param sign half-precision value with sign bit only -/// \return rounded underflowing half-precision value -/// \exception FE_UNDERFLOW -template -HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0) -{ -#if HALF_ERRHANDLING - raise(FE_UNDERFLOW); -#endif - return (R == std::round_toward_infinity) - ? (sign + 1 - (sign >> 15)) - : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) : sign; -} - -/// Round half-precision number. -/// \tparam R rounding mode to use -/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results -/// \param value finite half-precision number to round -/// \param g guard bit (most significant discarded bit) -/// \param s sticky bit (or of all but the most significant discarded bits) -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded or \a I is `true` -template -HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s) -{ -#if HALF_ERRHANDLING - value += (R == std::round_to_nearest) - ? (g & (s | value)) - : (R == std::round_toward_infinity) - ? (~(value >> 15) & (g | s)) - : (R == std::round_toward_neg_infinity) ? ((value >> 15) & (g | s)) : 0; - if((value & 0x7C00) == 0x7C00) - raise(FE_OVERFLOW); - else if(value & 0x7C00) - raise(FE_INEXACT, I || (g | s) != 0); - else - raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g | s) != 0); - return value; -#else - return (R == std::round_to_nearest) - ? (value + (g & (s | value))) - : (R == std::round_toward_infinity) - ? (value + (~(value >> 15) & (g | s))) - : (R == std::round_toward_neg_infinity) ? 
(value + ((value >> 15) & (g | s))) - : value; -#endif -} - -/// Round half-precision number to nearest integer value. -/// \tparam R rounding mode to use -/// \tparam E `true` for round to even, `false` for round away from zero -/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it -/// \param value half-precision value to round -/// \return half-precision bits for nearest integral value -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded and \a I is `true` -template -unsigned int integral(unsigned int value) -{ - unsigned int abs = value & 0x7FFF; - if(abs < 0x3C00) - { - raise(FE_INEXACT, I); - return ((R == std::round_to_nearest) - ? (0x3C00 & -static_cast(abs >= (0x3800 + E))) - : (R == std::round_toward_infinity) - ? (0x3C00 & -(~(value >> 15) & (abs != 0))) - : (R == std::round_toward_neg_infinity) - ? (0x3C00 & -static_cast(value > 0x8000)) - : 0) | - (value & 0x8000); - } - if(abs >= 0x6400) - return (abs > 0x7C00) ? signal(value) : value; - unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1; - raise(FE_INEXACT, I && (value & mask)); - return (((R == std::round_to_nearest) - ? ((1 << (exp - 1)) - (~(value >> exp) & E)) - : (R == std::round_toward_infinity) - ? (mask & ((value >> 15) - 1)) - : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) : 0) + - value) & - ~mask; -} - -/// Convert fixed point to half-precision floating-point. 
-/// \tparam R rounding mode to use -/// \tparam F number of fractional bits (at least 11) -/// \tparam S `true` for signed, `false` for unsigned -/// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F -/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results -/// \param m mantissa in Q1.F fixed point format -/// \param exp exponent -/// \param sign half-precision value with sign bit only -/// \param s sticky bit (or of all but the most significant already discarded bits) -/// \return value converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded or \a I is `true` -template -unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0) -{ - if(S) - { - uint32 msign = sign_mask(m); - m = (m ^ msign) - msign; - sign = msign & 0x8000; - } - if(N) - for(; m < (static_cast(1) << F) && exp; m <<= 1, --exp) - ; - else if(exp < 0) - return rounded(sign + (m >> (F - 10 - exp)), - (m >> (F - 11 - exp)) & 1, - s | ((m & ((static_cast(1) << (F - 11 - exp)) - 1)) != 0)); - return rounded(sign + (exp << 10) + (m >> (F - 10)), - (m >> (F - 11)) & 1, - s | ((m & ((static_cast(1) << (F - 11)) - 1)) != 0)); -} - -/// Convert IEEE single-precision to half-precision. -/// Credit for this goes to [Jeroen van der -/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). -/// \tparam R rounding mode to use -/// \param value single-precision value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half_impl(float value, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value), - (R == std::round_to_nearest) - ? 
_MM_FROUND_TO_NEAREST_INT - : (R == std::round_toward_zero) - ? _MM_FROUND_TO_ZERO - : (R == std::round_toward_infinity) - ? _MM_FROUND_TO_POS_INF - : (R == std::round_toward_neg_infinity) - ? _MM_FROUND_TO_NEG_INF - : _MM_FROUND_CUR_DIRECTION)); -#else - bits::type fbits; - std::memcpy(&fbits, &value, sizeof(float)); -#if 1 - unsigned int sign = (fbits >> 16) & 0x8000; - fbits &= 0x7FFFFFFF; - if(fbits >= 0x7F800000) - return sign | 0x7C00 | ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0); - if(fbits >= 0x47800000) - return overflow(sign); - if(fbits >= 0x38800000) - return rounded(sign | (((fbits >> 23) - 112) << 10) | ((fbits >> 13) & 0x3FF), - (fbits >> 12) & 1, - (fbits & 0xFFF) != 0); - if(fbits >= 0x33000000) - { - int i = 125 - (fbits >> 23); - fbits = (fbits & 0x7FFFFF) | 0x800000; - return rounded(sign | (fbits >> (i + 1)), - (fbits >> i) & 1, - (fbits & ((static_cast(1) << i) - 1)) != 0); - } - if(fbits != 0) - return underflow(sign); - return sign; -#else - static const uint16 base_table[512] = { - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, - 0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 
0x1400, 0x1800, 0x1C00, 0x2000, - 0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00, - 0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, - 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, - 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 
0x8004, 0x8008, - 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, - 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, - 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, - 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, - 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; - static const unsigned char shift_table[256] = { - 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, - 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, - 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, - 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 
24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, - 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; - int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; - fbits &= 0x7FFFFF; - uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); - return rounded(base_table[sexp] + (fbits >> i), - (m >> (i - 1)) & 1, - (((static_cast(1) << (i - 1)) - 1) & m) != 0); -#endif -#endif -} - -/// Convert IEEE double-precision to half-precision. -/// \tparam R rounding mode to use -/// \param value double-precision value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half_impl(double value, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - if(R == std::round_indeterminate) - return _mm_cvtsi128_si32( - _mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); -#endif - bits::type dbits; - std::memcpy(&dbits, &value, sizeof(double)); - uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; - unsigned int sign = (hi >> 16) & 0x8000; - hi &= 0x7FFFFFFF; - if(hi >= 0x7FF00000) - return sign | 0x7C00 | ((dbits & 0xFFFFFFFFFFFFF) ? 
(0x200 | ((hi >> 10) & 0x3FF)) : 0); - if(hi >= 0x40F00000) - return overflow(sign); - if(hi >= 0x3F100000) - return rounded(sign | (((hi >> 20) - 1008) << 10) | ((hi >> 10) & 0x3FF), - (hi >> 9) & 1, - ((hi & 0x1FF) | lo) != 0); - if(hi >= 0x3E600000) - { - int i = 1018 - (hi >> 20); - hi = (hi & 0xFFFFF) | 0x100000; - return rounded(sign | (hi >> (i + 1)), - (hi >> i) & 1, - ((hi & ((static_cast(1) << i) - 1)) | lo) != 0); - } - if((hi | lo) != 0) - return underflow(sign); - return sign; -} - -/// Convert non-IEEE floating-point to half-precision. -/// \tparam R rounding mode to use -/// \tparam T source type (builtin floating-point type) -/// \param value floating-point value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half_impl(T value, ...) -{ - unsigned int hbits = static_cast(builtin_signbit(value)) << 15; - if(value == T()) - return hbits; - if(builtin_isnan(value)) - return hbits | 0x7FFF; - if(builtin_isinf(value)) - return hbits | 0x7C00; - int exp; - std::frexp(value, &exp); - if(exp > 16) - return overflow(hbits); - if(exp < -13) - value = std::ldexp(value, 25); - else - { - value = std::ldexp(value, 12 - exp); - hbits |= ((exp + 13) << 10); - } - T ival, frac = std::modf(value, &ival); - int m = std::abs(static_cast(ival)); - return rounded(hbits + (m >> 1), m & 1, frac != T()); -} - -/// Convert floating-point to half-precision. 
-/// \tparam R rounding mode to use -/// \tparam T source type (builtin floating-point type) -/// \param value floating-point value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int float2half(T value) -{ - return float2half_impl(value, - bool_type < std::numeric_limits::is_iec559 && - sizeof(typename bits::type) == sizeof(T) > ()); -} - -/// Convert integer to half-precision floating-point. -/// \tparam R rounding mode to use -/// \tparam T type to convert (builtin integer type) -/// \param value integral value to convert -/// \return rounded half-precision value -/// \exception FE_OVERFLOW on overflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int int2half(T value) -{ - unsigned int bits = static_cast(value < 0) << 15; - if(!value) - return bits; - if(bits) - value = -value; - if(value > 0xFFFF) - return overflow(bits); - unsigned int m = static_cast(value), exp = 24; - for(; m < 0x400; m <<= 1, --exp) - ; - for(; m > 0x7FF; m >>= 1, ++exp) - ; - bits |= (exp << 10) + m; - return (exp > 24) ? rounded( - bits, (value >> (exp - 25)) & 1, (((1 << (exp - 25)) - 1) & value) != 0) - : bits; -} - -/// Convert half-precision to IEEE single-precision. -/// Credit for this goes to [Jeroen van der -/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). 
-/// \param value half-precision value to convert -/// \return single-precision value -inline float half2float_impl(unsigned int value, float, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); -#else -#if 0 - bits::type fbits = static_cast::type>(value&0x8000) << 16; - int abs = value & 0x7FFF; - if(abs) - { - fbits |= 0x38000000 << static_cast(abs>=0x7C00); - for(; abs<0x400; abs<<=1,fbits-=0x800000) ; - fbits += static_cast::type>(abs) << 13; - } -#else - static const bits::type mantissa_table[2048] = { - 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, - 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, - 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, - 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, - 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, - 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, - 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, - 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, - 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, - 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, - 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, - 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, - 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, - 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, - 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, - 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, - 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 
0x36E80000, 0x36EA0000, 0x36EC0000, - 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, - 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, - 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, - 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, - 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, - 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, - 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, - 0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, - 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, - 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, - 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, - 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, - 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, - 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, - 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, - 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, - 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, - 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, - 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, - 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, - 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, - 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, - 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 
0x378B0000, 0x378B8000, - 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, - 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, - 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, - 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, - 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, - 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, - 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, - 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, - 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, - 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, - 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, - 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, - 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, - 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, - 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, - 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, - 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, - 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, - 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, - 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, - 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, - 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, - 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 
0x37DC0000, - 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, - 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, - 0x37E38000, 0x37E40000, 0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, - 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, - 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, - 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, - 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, - 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, - 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, - 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, - 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, - 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, - 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, - 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, - 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, - 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, - 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, - 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, - 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, - 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, - 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, - 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, - 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, - 
0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, - 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, - 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, - 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, - 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, - 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, - 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, - 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, - 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, - 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, - 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, - 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, - 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, - 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, - 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, - 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, - 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, - 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, - 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, - 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, - 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 0x383A8000, 0x383AC000, 0x383B0000, - 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, - 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, - 0x383EC000, 
0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, - 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, - 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, - 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, - 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, - 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, - 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, - 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, - 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, - 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, - 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, - 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, - 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, - 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, - 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, - 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, - 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, - 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, - 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, - 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, - 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, - 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, - 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, - 0x38670000, 0x38674000, 
0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, - 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, - 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, - 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, - 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, - 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, - 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, - 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, - 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, - 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, - 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, - 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, - 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, - 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, - 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, - 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 0x38016000, - 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, - 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, - 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, - 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, - 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, - 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, - 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, - 0x3807A000, 0x3807C000, 0x3807E000, 
0x38080000, 0x38082000, 0x38084000, 0x38086000, - 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, - 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, - 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, - 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, - 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, - 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, - 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, - 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, - 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, - 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, - 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, - 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, - 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, - 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, - 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, - 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, - 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, - 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, - 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, - 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, - 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, - 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, - 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 
0x381C4000, 0x381C6000, 0x381C8000, - 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, - 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, - 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, - 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, - 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, - 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, - 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, - 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, - 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, - 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, - 0x38256000, 0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, - 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, - 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, - 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, - 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, - 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, - 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, - 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, - 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, - 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, - 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, - 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, - 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 
0x38308000, 0x3830A000, - 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, - 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, - 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, - 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, - 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, - 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, - 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, - 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, - 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, - 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, - 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, - 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, - 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, - 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, - 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, - 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, - 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, - 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, - 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, - 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, - 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, - 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, - 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 
0x3844C000, - 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, - 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, - 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, - 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, - 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, - 0x38494000, 0x38496000, 0x38498000, 0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, - 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, - 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, - 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, - 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, - 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, - 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, - 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, - 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, - 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, - 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, - 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, - 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, - 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, - 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, - 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, - 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, - 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, - 
0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, - 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, - 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, - 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, - 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, - 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, - 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, - 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, - 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, - 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, - 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, - 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, - 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, - 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, - 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, - 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, - 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, - 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, - 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, - 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, - 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, - 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, - 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, - 0x386D2000, 
0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 0x386DC000, 0x386DE000, - 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, - 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, - 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, - 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, - 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, - 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, - 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, - 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, - 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, - 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, - 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, - 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, - 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, - 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, - 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, - 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, - 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, - 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, - 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, - 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, - 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; - static const bits::type exponent_table[64] = { - 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, - 0x03800000, 
0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, - 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, - 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, - 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, - 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, - 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, - 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, - 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, - 0xC7800000}; - static const unsigned short offset_table[64] = { - 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, - 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; - bits::type fbits = - mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10]; -#endif - float out; - std::memcpy(&out, &fbits, sizeof(float)); - return out; -#endif -} - -/// Convert half-precision to IEEE double-precision. 
-/// \param value half-precision value to convert -/// \return double-precision value -inline double half2float_impl(unsigned int value, double, true_type) -{ -#if HALF_ENABLE_F16C_INTRINSICS - return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); -#else - uint32 hi = static_cast(value & 0x8000) << 16; - unsigned int abs = value & 0x7FFF; - if(abs) - { - hi |= 0x3F000000 << static_cast(abs >= 0x7C00); - for(; abs < 0x400; abs <<= 1, hi -= 0x100000) - ; - hi += static_cast(abs) << 10; - } - bits::type dbits = static_cast::type>(hi) << 32; - double out; - std::memcpy(&out, &dbits, sizeof(double)); - return out; -#endif -} - -/// Convert half-precision to non-IEEE floating-point. -/// \tparam T type to convert to (builtin integer type) -/// \param value half-precision value to convert -/// \return floating-point value -template -T half2float_impl(unsigned int value, T, ...) -{ - T out; - unsigned int abs = value & 0x7FFF; - if(abs > 0x7C00) - out = - (std::numeric_limits::has_signaling_NaN && !(abs & 0x200)) - ? std::numeric_limits::signaling_NaN() - : std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); - else if(abs == 0x7C00) - out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - else if(abs > 0x3FF) - out = std::ldexp(static_cast((abs & 0x3FF) | 0x400), (abs >> 10) - 25); - else - out = std::ldexp(static_cast(abs), -24); - return (value & 0x8000) ? -out : out; -} - -/// Convert half-precision to floating-point. -/// \tparam T type to convert to (builtin integer type) -/// \param value half-precision value to convert -/// \return floating-point value -template -T half2float(unsigned int value) -{ - return half2float_impl(value, - T(), - bool_type < std::numeric_limits::is_iec559 && - sizeof(typename bits::type) == sizeof(T) > ()); -} - -/// Convert half-precision floating-point to integer. 
-/// \tparam R rounding mode to use -/// \tparam E `true` for round to even, `false` for round away from zero -/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it -/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding -/// any implicit sign bits) -/// \param value half-precision value to convert -/// \return rounded integer value -/// \exception FE_INVALID if value is not representable in type \a T -/// \exception FE_INEXACT if value had to be rounded and \a I is `true` -template -T half2int(unsigned int value) -{ - unsigned int abs = value & 0x7FFF; - if(abs >= 0x7C00) - { - raise(FE_INVALID); - return (value & 0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); - } - if(abs < 0x3800) - { - raise(FE_INEXACT, I); - return (R == std::round_toward_infinity) - ? T(~(value >> 15) & (abs != 0)) - : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000) : T(); - } - int exp = 25 - (abs >> 10); - unsigned int m = (value & 0x3FF) | 0x400; - int32 i = static_cast( - (exp <= 0) - ? (m << -exp) - : ((m + ((R == std::round_to_nearest) ? ((1 << (exp - 1)) - (~(m >> exp) & E)) - : (R == std::round_toward_infinity) - ? (((1 << exp) - 1) & ((value >> 15) - 1)) - : (R == std::round_toward_neg_infinity) - ? (((1 << exp) - 1) & -(value >> 15)) - : 0)) >> - exp)); - if((!std::numeric_limits::is_signed && (value & 0x8000)) || - (std::numeric_limits::digits < 16 && - ((value & 0x8000) ? (-i < std::numeric_limits::min()) - : (i > std::numeric_limits::max())))) - raise(FE_INVALID); - else if(I && exp > 0 && (m & ((1 << exp) - 1))) - raise(FE_INEXACT); - return static_cast((value & 0x8000) ? -i : i); -} - -/// \} -/// \name Mathematics -/// \{ - -/// upper part of 64-bit multiplication. 
-/// \tparam R rounding mode to use -/// \param x first factor -/// \param y second factor -/// \return upper 32 bit of \a x * \a y -template -uint32 mulhi(uint32 x, uint32 y) -{ - uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), - c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); - return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + - ((R == std::round_to_nearest) - ? ((c >> 15) & 1) - : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0); -} - -/// 64-bit multiplication. -/// \param x first factor -/// \param y second factor -/// \return upper 32 bit of \a x * \a y rounded to nearest -inline uint32 multiply64(uint32 x, uint32 y) -{ -#if HALF_ENABLE_CPP11_LONG_LONG - return static_cast( - (static_cast(x) * static_cast(y) + 0x80000000) >> - 32); -#else - return mulhi(x, y); -#endif -} - -/// 64-bit division. -/// \param x upper 32 bit of dividend -/// \param y divisor -/// \param s variable to store sticky bit for rounding -/// \return (\a x << 32) / \a y -inline uint32 divide64(uint32 x, uint32 y, int& s) -{ -#if HALF_ENABLE_CPP11_LONG_LONG - unsigned long long xx = static_cast(x) << 32; - return s = (xx % y != 0), static_cast(xx / y); -#else - y >>= 1; - uint32 rem = x, div = 0; - for(unsigned int i = 0; i < 32; ++i) - { - div <<= 1; - if(rem >= y) - { - rem -= y; - div |= 1; - } - rem <<= 1; - } - return s = rem > 1, div; -#endif -} - -/// Half precision positive modulus. 
-/// \tparam Q `true` to compute full quotient, `false` else -/// \tparam R `true` to compute signed remainder, `false` for positive remainder -/// \param x first operand as positive finite half-precision value -/// \param y second operand as positive finite half-precision value -/// \param quo adress to store quotient at, `nullptr` if \a Q `false` -/// \return modulus of \a x / \a y -template -unsigned int mod(unsigned int x, unsigned int y, int* quo = NULL) -{ - unsigned int q = 0; - if(x > y) - { - int absx = x, absy = y, expx = 0, expy = 0; - for(; absx < 0x400; absx <<= 1, --expx) - ; - for(; absy < 0x400; absy <<= 1, --expy) - ; - expx += absx >> 10; - expy += absy >> 10; - int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; - for(int d = expx - expy; d; --d) - { - if(!Q && mx == my) - return 0; - if(mx >= my) - { - mx -= my; - q += Q; - } - mx <<= 1; - q <<= static_cast(Q); - } - if(!Q && mx == my) - return 0; - if(mx >= my) - { - mx -= my; - ++q; - } - if(Q) - { - q &= (1 << (std::numeric_limits::digits - 1)) - 1; - if(!mx) - return *quo = q, 0; - } - for(; mx < 0x400; mx <<= 1, --expy) - ; - x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy)); - } - if(R) - { - unsigned int a, b; - if(y < 0x800) - { - a = (x < 0x400) ? (x << 1) : (x + 0x400); - b = y; - } - else - { - a = x; - b = y - 0x400; - } - if(a > b || (a == b && (q & 1))) - { - int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF); - int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) - - (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d)); - for(; m < 0x800 && exp > 1; m <<= 1, --exp) - ; - x = 0x8000 + ((exp - 1) << 10) + (m >> 1); - q += Q; - } - } - if(Q) - *quo = q; - return x; -} - -/// Fixed point square root. 
-/// \tparam F number of fractional bits -/// \param r radicand in Q1.F fixed point format -/// \param exp exponent -/// \return square root as Q1.F/2 -template -uint32 sqrt(uint32& r, int& exp) -{ - int i = exp & 1; - r <<= i; - exp = (exp - i) / 2; - uint32 m = 0; - for(uint32 bit = static_cast(1) << F; bit; bit >>= 2) - { - if(r < m + bit) - m >>= 1; - else - { - r -= m + bit; - m = (m >> 1) + bit; - } - } - return m; -} - -/// Fixed point binary exponential. -/// This uses the BKM algorithm in E-mode. -/// \param m exponent in [0,1) as Q0.31 -/// \param n number of iterations (at most 32) -/// \return 2 ^ \a m as Q1.31 -inline uint32 exp2(uint32 m, unsigned int n = 32) -{ - static const uint32 logs[] = { - 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, - 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, - 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, - 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, - 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; - if(!m) - return 0x80000000; - uint32 mx = 0x80000000, my = 0; - for(unsigned int i = 1; i < n; ++i) - { - uint32 mz = my + logs[i]; - if(mz <= m) - { - my = mz; - mx += mx >> i; - } - } - return mx; -} - -/// Fixed point binary logarithm. -/// This uses the BKM algorithm in L-mode. 
-/// \param m mantissa in [1,2) as Q1.30 -/// \param n number of iterations (at most 32) -/// \return log2(\a m) as Q0.31 -inline uint32 log2(uint32 m, unsigned int n = 32) -{ - static const uint32 logs[] = { - 0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1, - 0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B, - 0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B, - 0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017, - 0x0000000C, 0x00000006, 0x00000003, 0x00000001}; - if(m == 0x40000000) - return 0; - uint32 mx = 0x40000000, my = 0; - for(unsigned int i = 1; i < n; ++i) - { - uint32 mz = mx + (mx >> i); - if(mz <= m) - { - mx = mz; - my += logs[i]; - } - } - return my; -} - -/// Fixed point sine and cosine. -/// This uses the CORDIC algorithm in rotation mode. -/// \param mz angle in [-pi/2,pi/2] as Q1.30 -/// \param n number of iterations (at most 31) -/// \return sine and cosine of \a mz as Q1.30 -inline std::pair sincos(uint32 mz, unsigned int n = 31) -{ - static const uint32 angles[] = { - 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, - 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, - 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, - 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, - 0x00000004, 0x00000002, 0x00000001}; - uint32 mx = 0x26DD3B6A, my = 0; - for(unsigned int i = 0; i < n; ++i) - { - uint32 sign = sign_mask(mz); - uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign; - uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign; - mx = tx; - my = ty; - mz -= (angles[i] ^ sign) - sign; - } - return std::make_pair(my, mx); -} - -/// Fixed point arc tangent. -/// This uses the CORDIC algorithm in vectoring mode. 
-/// \param my y coordinate as Q0.30 -/// \param mx x coordinate as Q0.30 -/// \param n number of iterations (at most 31) -/// \return arc tangent of \a my / \a mx as Q1.30 -inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31) -{ - static const uint32 angles[] = { - 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, - 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, - 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, - 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, - 0x00000004, 0x00000002, 0x00000001}; - uint32 mz = 0; - for(unsigned int i = 0; i < n; ++i) - { - uint32 sign = sign_mask(my); - uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign; - uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign; - mx = tx; - my = ty; - mz += (angles[i] ^ sign) - sign; - } - return mz; -} - -/// Reduce argument for trigonometric functions. -/// \param abs half-precision floating-point value -/// \param k value to take quarter period -/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30 -inline uint32 angle_arg(unsigned int abs, int& k) -{ - uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10); - int exp = (abs >> 10) + (abs <= 0x3FF) - 15; - if(abs < 0x3A48) - return k = 0, m << (exp + 20); -#if HALF_ENABLE_CPP11_LONG_LONG - unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1, - yi = (y + (mask >> 1)) & ~mask, f = y - yi; - uint32 sign = -static_cast(f >> 63); - k = static_cast(yi >> (62 - exp)); - return (multiply64(static_cast((sign ? 
-f : f) >> (31 - exp)), 0xC90FDAA2) ^ sign) - - sign; -#else - uint32 yh = m * 0xA2F98 + mulhi(m, 0x36E4E442), - yl = (m * 0x36E4E442) & 0xFFFFFFFF; - uint32 mask = (static_cast(1) << (30 - exp)) - 1, yi = (yh + (mask >> 1)) & ~mask, - sign = -static_cast(yi > yh); - k = static_cast(yi >> (30 - exp)); - uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign; - return (multiply64((exp > -1) - ? (((fh << (1 + exp)) & 0xFFFFFFFF) | ((fl & 0xFFFFFFFF) >> (31 - exp))) - : fh, - 0xC90FDAA2) ^ - sign) - - sign; -#endif -} - -/// Get arguments for atan2 function. -/// \param abs half-precision floating-point value -/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 -inline std::pair atan2_args(unsigned int abs) -{ - int exp = -15; - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my; - int rexp = 2 * exp; - r = 0x40000000 - - ((rexp > -31) ? ((r >> -rexp) | ((r & ((static_cast(1) << -rexp) - 1)) != 0)) : 1); - for(rexp = 0; r < 0x40000000; r <<= 1, --rexp) - ; - uint32 mx = sqrt<30>(r, rexp); - int d = exp - rexp; - if(d < 0) - return std::make_pair((d < -14) ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1)) - : (my << (14 + d)), - (mx << 14) + (r << 13) / mx); - if(d > 0) - return std::make_pair(my << 14, - (d > 14) - ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1)) - : ((d == 14) ? 
mx : ((mx << (14 - d)) + (r << (13 - d)) / mx))); - return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx); -} - -/// Get exponentials for hyperbolic computation -/// \param abs half-precision floating-point value -/// \param exp variable to take unbiased exponent of larger result -/// \param n number of BKM iterations (at most 32) -/// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent -inline std::pair hyperbolic_args(unsigned int abs, int& exp, unsigned int n = 32) -{ - uint32 mx = detail::multiply64(static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, - 0xB8AA3B29), - my; - int e = (abs >> 10) + (abs <= 0x3FF); - if(e < 14) - { - exp = 0; - mx >>= 14 - e; - } - else - { - exp = mx >> (45 - e); - mx = (mx << (e - 14)) & 0x7FFFFFFF; - } - mx = exp2(mx, n); - int d = exp << 1, s; - if(mx > 0x80000000) - { - my = divide64(0x80000000, mx, s); - my |= s; - ++d; - } - else - my = mx; - return std::make_pair( - mx, (d < 31) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1); -} - -/// Postprocessing for binary exponential. 
-/// \tparam R rounding mode to use -/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results -/// \param m mantissa as Q1.31 -/// \param exp absolute value of unbiased exponent -/// \param esign sign of actual exponent -/// \param sign sign bit of result -/// \return value converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded or \a I is `true` -template -unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0) -{ - int s = 0; - if(esign) - { - if(m > 0x80000000) - { - m = divide64(0x80000000, m, s); - ++exp; - } - if(exp > 25) - return underflow(sign); - else if(exp == 25) - return rounded(sign, 1, (m & 0x7FFFFFFF) != 0); - exp = -exp; - } - else if(exp > 15) - return overflow(sign); - return fixed2half(m, exp + 14, sign, s); -} - -/// Postprocessing for binary logarithm. -/// \tparam R rounding mode to use -/// \tparam L logarithm for base transformation as Q1.31 -/// \param m fractional part of logarithm as Q0.31 -/// \param ilog signed integer part of logarithm -/// \param exp biased exponent of result -/// \param sign sign bit of result -/// \return value base-transformed and converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) -{ - uint32 msign = sign_mask(ilog); - m = (((static_cast(ilog) << 27) + (m >> 4)) ^ msign) - msign; - if(!m) - return 0; - for(; m < 0x80000000; m <<= 1, --exp) - ; - int i = m >= L, s; - exp += i; - m >>= 1 + i; - sign ^= msign & 0x8000; - if(exp < -11) - return underflow(sign); - m = divide64(m, L, s); - return fixed2half(m, exp, sign, 1); -} - -/// Hypotenuse square root and postprocessing. 
-/// \tparam R rounding mode to use -/// \param r mantissa as Q2.30 -/// \param exp unbiased exponent -/// \return square root converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if value had to be rounded -template -unsigned int hypot_post(uint32 r, int exp) -{ - int i = r >> 31; - if((exp += i) > 46) - return overflow(); - if(exp < -34) - return underflow(); - r = (r >> i) | (r & i); - uint32 m = sqrt<30>(r, exp += 15); - return fixed2half(m, exp - 1, 0, r != 0); -} - -/// Division and postprocessing for tangents. -/// \tparam R rounding mode to use -/// \param my dividend as Q1.31 -/// \param mx divisor as Q1.31 -/// \param exp biased exponent of result -/// \param sign sign bit of result -/// \return quotient converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) -{ - int i = my >= mx, s; - exp += i; - if(exp > 29) - return overflow(sign); - if(exp < -11) - return underflow(sign); - uint32 m = divide64(my >> (i + 1), mx, s); - return fixed2half(m, exp, sign, s); -} - -/// Area function and postprocessing. -/// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = -/// log(x+sqrt(x^2+|-1))`. 
-/// \tparam R rounding mode to use -/// \tparam S `true` for asinh, `false` for acosh -/// \param arg half-precision argument -/// \return asinh|acosh(\a arg) converted to half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int area(unsigned int arg) -{ - int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, ilog, i; - uint32 mx = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, my, r; - for(; abs < 0x400; abs <<= 1, --expy) - ; - expy += abs >> 10; - r = ((abs & 0x3FF) | 0x400) << 5; - r *= r; - i = r >> 31; - expy = 2 * expy + i; - r >>= i; - if(S) - { - if(expy < 0) - { - r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | - ((r & ((static_cast(1) << -expy) - 1)) != 0)) - : 1); - expy = 0; - } - else - { - r += 0x40000000 >> expy; - i = r >> 31; - r = (r >> i) | (r & i); - expy += i; - } - } - else - { - r -= 0x40000000 >> expy; - for(; r < 0x40000000; r <<= 1, --expy) - ; - } - my = sqrt<30>(r, expy); - my = (my << 15) + (r << 14) / my; - if(S) - { - mx >>= expy - expx; - ilog = expy; - } - else - { - my >>= expx - expy; - ilog = expx; - } - my += mx; - i = my >> 31; - static const int G = S && (R == std::round_to_nearest); - return log2_post( - log2(my >> i, 26 + S + G) + (G << 3), ilog + i, 17, arg & (static_cast(S) << 15)); -} - -/// Class for 1.31 unsigned floating-point computation -struct f31 -{ - /// Constructor. - /// \param mant mantissa as 1.31 - /// \param e exponent - HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} - - /// Constructor. - /// \param abs unsigned half-precision value - f31(unsigned int abs) : exp(-15) - { - for(; abs < 0x400; abs <<= 1, --exp) - ; - m = static_cast((abs & 0x3FF) | 0x400) << 21; - exp += (abs >> 10); - } - - /// Addition operator. 
- /// \param a first operand - /// \param b second operand - /// \return \a a + \a b - friend f31 operator+(f31 a, f31 b) - { - if(b.exp > a.exp) - std::swap(a, b); - int d = a.exp - b.exp; - uint32 m = a.m + ((d < 32) ? (b.m >> d) : 0); - int i = (m & 0xFFFFFFFF) < a.m; - return f31(((m + i) >> i) | 0x80000000, a.exp + i); - } - - /// Subtraction operator. - /// \param a first operand - /// \param b second operand - /// \return \a a - \a b - friend f31 operator-(f31 a, f31 b) - { - int d = a.exp - b.exp, exp = a.exp; - uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0); - if(!m) - return f31(0, -32); - for(; m < 0x80000000; m <<= 1, --exp) - ; - return f31(m, exp); - } - - /// Multiplication operator. - /// \param a first operand - /// \param b second operand - /// \return \a a * \a b - friend f31 operator*(f31 a, f31 b) - { - uint32 m = multiply64(a.m, b.m); - int i = m >> 31; - return f31(m << (1 - i), a.exp + b.exp + i); - } - - /// Division operator. - /// \param a first operand - /// \param b second operand - /// \return \a a / \a b - friend f31 operator/(f31 a, f31 b) - { - int i = a.m >= b.m, s; - uint32 m = divide64((a.m + i) >> i, b.m, s); - return f31(m, a.exp - b.exp + i - 1); - } - - uint32 m; ///< mantissa as 1.31. - int exp; ///< exponent. -}; - -/// Error function and postprocessing. -/// This computes the value directly in Q1.31 using the approximations given -/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). 
-/// \tparam R rounding mode to use -/// \tparam C `true` for comlementary error function, `false` else -/// \param arg half-precision function argument -/// \return approximated value of error function in half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if no other exception occurred -template -unsigned int erf(unsigned int arg) -{ - unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; - f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), - t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), t2 = t * t; - f31 e = ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + f31(0x82790637, -2) - - (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) * - t / - ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) - : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp))); - return (!C || sign) - ? fixed2half( - 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) - : (e.exp < -25) - ? underflow() - : fixed2half(e.m >> 1, e.exp + 14, 0, e.m & 1); -} - -/// Gamma function and postprocessing. -/// This approximates the value of either the gamma function or its logarithm directly in Q1.31. 
-/// \tparam R rounding mode to use -/// \tparam L `true` for lograithm of gamma function, `false` for gamma function -/// \param arg half-precision floating-point value -/// \return lgamma/tgamma(\a arg) in half-precision -/// \exception FE_OVERFLOW on overflows -/// \exception FE_UNDERFLOW on underflows -/// \exception FE_INEXACT if \a arg is not a positive integer -template -unsigned int gamma(unsigned int arg) -{ - /* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, - -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, - 0.0114684895434781459556 }; double t = arg + 4.65, s = p[0]; for(unsigned int i=0; i<5; ++i) - s += p[i+1] / (arg+i); - return std::log(s) + (arg-0.5)*std::log(t) - t; -*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); - unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; - bool bsign = sign != 0; - f31 z(abs), x = sign ? (z + f31(0x80000000, 0)) : z, t = x + f31(0x94CCCCCD, 2), - s = f31(0xA06C9901, 1) + f31(0xBBE654E2, -7) / (x + f31(0x80000000, 2)) + - f31(0xA1CE6098, 6) / (x + f31(0x80000000, 1)) + f31(0xE1868CB7, 7) / x - - f31(0x8625E279, 8) / (x + f31(0x80000000, 0)) - - f31(0xA03E158F, 2) / (x + f31(0xC0000000, 1)); - int i = (s.exp >= 2) + (s.exp >= 4) + (s.exp >= 8) + (s.exp >= 16); - s = f31((static_cast(s.exp) << (31 - i)) + (log2(s.m >> 1, 28) >> i), i) / lbe; - if(x.exp != -1 || x.m != 0x80000000) - { - i = (t.exp >= 2) + (t.exp >= 4) + (t.exp >= 8); - f31 l = f31((static_cast(t.exp) << (31 - i)) + (log2(t.m >> 1, 30) >> i), i) / lbe; - s = (x.exp < -1) ? (s - (f31(0x80000000, -1) - x) * l) - : (s + (x - f31(0x80000000, -1)) * l); - } - s = x.exp ? 
(s - t) : (t - s); - if(bsign) - { - if(z.exp >= 0) - { - sign &= (L | ((z.m >> (31 - z.exp)) & 1)) - 1; - for(z = f31((z.m << (1 + z.exp)) & 0xFFFFFFFF, -1); z.m < 0x80000000; - z.m <<= 1, --z.exp) - ; - } - if(z.exp == -1) - z = f31(0x80000000, 0) - z; - if(z.exp < -1) - { - z = z * pi; - z.m = sincos(z.m >> (1 - z.exp), 30).first; - for(z.exp = 1; z.m < 0x80000000; z.m <<= 1, --z.exp) - ; - } - else - z = f31(0x80000000, 0); - } - if(L) - { - if(bsign) - { - f31 l(0x92868247, 0); - if(z.exp < 0) - { - uint32 m = log2((z.m + 1) >> 1, 27); - z = f31(-((static_cast(z.exp) << 26) + (m >> 5)), 5); - for(; z.m < 0x80000000; z.m <<= 1, --z.exp) - ; - l = l + z / lbe; - } - sign = static_cast(x.exp && (l.exp < s.exp || (l.exp == s.exp && l.m < s.m))) - << 15; - s = sign ? (s - l) : x.exp ? (l - s) : (l + s); - } - else - { - sign = static_cast(x.exp == 0) << 15; - if(s.exp < -24) - return underflow(sign); - if(s.exp > 15) - return overflow(sign); - } - } - else - { - s = s * lbe; - uint32 m; - if(s.exp < 0) - { - m = s.m >> -s.exp; - s.exp = 0; - } - else - { - m = (s.m << s.exp) & 0x7FFFFFFF; - s.exp = (s.m >> (31 - s.exp)); - } - s.m = exp2(m, 27); - if(!x.exp) - s = f31(0x80000000, 0) / s; - if(bsign) - { - if(z.exp < 0) - s = s * z; - s = pi / s; - if(s.exp < -24) - return underflow(sign); - } - else if(z.exp > 0 && !(z.m & ((1 << (31 - z.exp)) - 1))) - return ((s.exp + 14) << 10) + (s.m >> 21); - if(s.exp > 15) - return overflow(sign); - } - return fixed2half(s.m, s.exp + 14, sign); -} -/// \} - -template -struct half_caster; -} // namespace detail - -/// Half-precision floating-point type. -/// This class implements an IEEE-conformant half-precision floating-point type with the usual -/// arithmetic -/// operators and conversions. It is implicitly convertible to single-precision floating-point, -/// which makes artihmetic -/// expressions and functions with mixed-type operands to be of the most precise operand type. 
-/// -/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's -/// less strict and -/// extended definitions it is both a standard layout type and a trivially copyable type (even if -/// not a POD type), which -/// means it can be standard-conformantly copied using raw binary copies. But in this context some -/// more words about the -/// actual size of the type. Although the half is representing an IEEE 16-bit type, it does not -/// neccessarily have to be of -/// exactly 16-bits size. But on any reasonable implementation the actual binary representation of -/// this type will most -/// probably not ivolve any additional "magic" or padding beyond the simple binary representation of -/// the underlying 16-bit -/// IEEE number, even if not strictly guaranteed by the standard. But even then it only has an -/// actual size of 16 bits if -/// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this -/// should be the case on -/// nearly any reasonable platform. -/// -/// So if your C++ implementation is not totally exotic or imposes special alignment requirements, -/// it is a reasonable -/// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE -/// representation. -class half -{ - public: - /// \name Construction and assignment - /// \{ - - /// Default constructor. - /// This initializes the half to 0. Although this does not match the builtin types' - /// default-initialization semantics - /// and may be less efficient than no initialization, it is needed to provide proper - /// value-initialization semantics. - HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} - - /// Conversion constructor. - /// \param rhs float to convert - /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding - explicit half(float rhs) - : data_(static_cast(detail::float2half(rhs))) - { - } - - /// Conversion to single-precision. 
- /// \return single precision value representing expression value - operator float() const { return detail::half2float(data_); } - - /// Assignment operator. - /// \param rhs single-precision value to copy from - /// \return reference to this half - /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding - half& operator=(float rhs) - { - data_ = static_cast(detail::float2half(rhs)); - return *this; - } - - /// \} - /// \name Arithmetic updates - /// \{ - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to add - /// \return reference to this half - /// \exception FE_... according to operator+(half,half) - half& operator+=(half rhs) { return *this = *this + rhs; } - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to subtract - /// \return reference to this half - /// \exception FE_... according to operator-(half,half) - half& operator-=(half rhs) { return *this = *this - rhs; } - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to multiply with - /// \return reference to this half - /// \exception FE_... according to operator*(half,half) - half& operator*=(half rhs) { return *this = *this * rhs; } - - /// Arithmetic assignment. - /// \tparam T type of concrete half expression - /// \param rhs half expression to divide by - /// \return reference to this half - /// \exception FE_... according to operator/(half,half) - half& operator/=(half rhs) { return *this = *this / rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to add - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator+=(float rhs) { return *this = *this + rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to subtract - /// \return reference to this half - /// \exception FE_... 
according to operator=() - half& operator-=(float rhs) { return *this = *this - rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to multiply with - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator*=(float rhs) { return *this = *this * rhs; } - - /// Arithmetic assignment. - /// \param rhs single-precision value to divide by - /// \return reference to this half - /// \exception FE_... according to operator=() - half& operator/=(float rhs) { return *this = *this / rhs; } - - /// \} - /// \name Increment and decrement - /// \{ - - /// Prefix increment. - /// \return incremented half value - /// \exception FE_... according to operator+(half,half) - half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } - - /// Prefix decrement. - /// \return decremented half value - /// \exception FE_... according to operator-(half,half) - half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } - - /// Postfix increment. - /// \return non-incremented half value - /// \exception FE_... according to operator+(half,half) - half operator++(int) - { - half out(*this); - ++*this; - return out; - } - - /// Postfix decrement. - /// \return non-decremented half value - /// \exception FE_... according to operator-(half,half) - half operator--(int) - { - half out(*this); - --*this; - return out; - } - /// \} - - private: - /// Rounding mode to use - static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); - - /// Constructor. 
- /// \param bits binary representation to set half to - HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT - : data_(static_cast(bits)) - { - } - - /// Internal binary representation - detail::uint16 data_; - -#ifndef HALF_DOXYGEN_ONLY - friend HALF_CONSTEXPR_NOERR bool operator==(half, half); - friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); - friend HALF_CONSTEXPR_NOERR bool operator<(half, half); - friend HALF_CONSTEXPR_NOERR bool operator>(half, half); - friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); - friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); - friend HALF_CONSTEXPR half operator-(half); - friend half operator+(half, half); - friend half operator-(half, half); - friend half operator*(half, half); - friend half operator/(half, half); - template - friend std::basic_ostream& operator<<(std::basic_ostream&, half); - template - friend std::basic_istream& operator>>(std::basic_istream&, half&); - friend HALF_CONSTEXPR half fabs(half); - friend half fmod(half, half); - friend half remainder(half, half); - friend half remquo(half, half, int*); - friend half fma(half, half, half); - friend HALF_CONSTEXPR_NOERR half fmax(half, half); - friend HALF_CONSTEXPR_NOERR half fmin(half, half); - friend half fdim(half, half); - friend half nanh(const char*); - friend half exp(half); - friend half exp2(half); - friend half expm1(half); - friend half log(half); - friend half log10(half); - friend half log2(half); - friend half log1p(half); - friend half sqrt(half); - friend half cbrt(half); - friend half hypot(half, half); - friend half hypot(half, half, half); - friend half pow(half, half); - friend void sincos(half, half*, half*); - friend half sin(half); - friend half cos(half); - friend half tan(half); - friend half asin(half); - friend half acos(half); - friend half atan(half); - friend half atan2(half, half); - friend half sinh(half); - friend half cosh(half); - friend half tanh(half); - friend half asinh(half); - friend 
half acosh(half); - friend half atanh(half); - friend half erf(half); - friend half erfc(half); - friend half lgamma(half); - friend half tgamma(half); - friend half ceil(half); - friend half floor(half); - friend half trunc(half); - friend half round(half); - friend long lround(half); - friend half rint(half); - friend long lrint(half); - friend half nearbyint(half); -#ifdef HALF_ENABLE_CPP11_LONG_LONG - friend long long llround(half); - friend long long llrint(half); -#endif - friend half frexp(half, int*); - friend half scalbln(half, long); - friend half modf(half, half*); - friend int ilogb(half); - friend half logb(half); - friend half nextafter(half, half); - friend half nexttoward(half, long double); - friend HALF_CONSTEXPR half copysign(half, half); - friend HALF_CONSTEXPR int fpclassify(half); - friend HALF_CONSTEXPR bool isfinite(half); - friend HALF_CONSTEXPR bool isinf(half); - friend HALF_CONSTEXPR bool isnan(half); - friend HALF_CONSTEXPR bool isnormal(half); - friend HALF_CONSTEXPR bool signbit(half); - friend HALF_CONSTEXPR bool isgreater(half, half); - friend HALF_CONSTEXPR bool isgreaterequal(half, half); - friend HALF_CONSTEXPR bool isless(half, half); - friend HALF_CONSTEXPR bool islessequal(half, half); - friend HALF_CONSTEXPR bool islessgreater(half, half); - template - friend struct detail::half_caster; - friend class std::numeric_limits; -#if HALF_ENABLE_CPP11_HASH - friend struct std::hash; -#endif -#if HALF_ENABLE_CPP11_USER_LITERALS - friend half literal::operator"" _h(long double); -#endif -#endif -}; - -#if HALF_ENABLE_CPP11_USER_LITERALS -namespace literal { -/// Half literal. -/// While this returns a properly rounded half-precision value, half literals can unfortunately not -/// be constant -/// expressions due to rather involved conversions. So don't expect this to be a literal literal -/// without involving -/// conversion operations at runtime. It is a convenience feature, not a performance optimization. 
-/// \param value literal value -/// \return half with of given value (possibly rounded) -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator"" _h(long double value) -{ - return half(detail::binary, detail::float2half(value)); -} -} // namespace literal -#endif - -namespace detail { -/// Helper class for half casts. -/// This class template has to be specialized for all valid cast arguments to define an appropriate -/// static -/// `cast` member function and a corresponding `type` member denoting its return type. -/// \tparam T destination type -/// \tparam U source type -/// \tparam R rounding mode to use -template -struct half_caster -{ -}; -template -struct half_caster -{ -#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS - static_assert(std::is_arithmetic::value, "half_cast from non-arithmetic type unsupported"); -#endif - - static half cast(U arg) { return cast_impl(arg, is_float()); }; - - private: - static half cast_impl(U arg, true_type) { return half(binary, float2half(arg)); } - static half cast_impl(U arg, false_type) { return half(binary, int2half(arg)); } -}; -template -struct half_caster -{ -#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS - static_assert(std::is_arithmetic::value, "half_cast to non-arithmetic type unsupported"); -#endif - - static T cast(half arg) { return cast_impl(arg, is_float()); } - - private: - static T cast_impl(half arg, true_type) { return half2float(arg.data_); } - static T cast_impl(half arg, false_type) { return half2int(arg.data_); } -}; -template -struct half_caster -{ - static half cast(half arg) { return arg; } -}; -} // namespace detail -} // namespace half_float - -/// Extensions to the C++ standard library. -namespace std { -/// Numeric limits for half-precision floats. 
-/// **See also:** Documentation for -/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits) -template <> -class numeric_limits -{ - public: - /// Is template specialization. - static HALF_CONSTEXPR_CONST bool is_specialized = true; - - /// Supports signed values. - static HALF_CONSTEXPR_CONST bool is_signed = true; - - /// Is not an integer type. - static HALF_CONSTEXPR_CONST bool is_integer = false; - - /// Is not exact. - static HALF_CONSTEXPR_CONST bool is_exact = false; - - /// Doesn't provide modulo arithmetic. - static HALF_CONSTEXPR_CONST bool is_modulo = false; - - /// Has a finite set of values. - static HALF_CONSTEXPR_CONST bool is_bounded = true; - - /// IEEE conformant. - static HALF_CONSTEXPR_CONST bool is_iec559 = true; - - /// Supports infinity. - static HALF_CONSTEXPR_CONST bool has_infinity = true; - - /// Supports quiet NaNs. - static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true; - - /// Supports signaling NaNs. - static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true; - - /// Supports subnormal values. - static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present; - - /// Supports no denormalization detection. - static HALF_CONSTEXPR_CONST bool has_denorm_loss = false; - -#if HALF_ERRHANDLING_THROWS - static HALF_CONSTEXPR_CONST bool traps = true; -#else - /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is - /// acitvated. - static HALF_CONSTEXPR_CONST bool traps = false; -#endif - - /// Does not support no pre-rounding underflow detection. - static HALF_CONSTEXPR_CONST bool tinyness_before = false; - - /// Rounding mode. - static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style; - - /// Significant digits. - static HALF_CONSTEXPR_CONST int digits = 11; - - /// Significant decimal digits. - static HALF_CONSTEXPR_CONST int digits10 = 3; - - /// Required decimal digits to represent all possible values. 
- static HALF_CONSTEXPR_CONST int max_digits10 = 5; - - /// Number base. - static HALF_CONSTEXPR_CONST int radix = 2; - - /// One more than smallest exponent. - static HALF_CONSTEXPR_CONST int min_exponent = -13; - - /// Smallest normalized representable power of 10. - static HALF_CONSTEXPR_CONST int min_exponent10 = -4; - - /// One more than largest exponent - static HALF_CONSTEXPR_CONST int max_exponent = 16; - - /// Largest finitely representable power of 10. - static HALF_CONSTEXPR_CONST int max_exponent10 = 4; - - /// Smallest positive normal value. - static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x0400); - } - - /// Smallest finite value. - static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0xFBFF); - } - - /// Largest finite value. - static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7BFF); - } - - /// Difference between 1 and next representable value. - static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x1400); - } - - /// Maximum rounding error in ULP (units in the last place). - static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, - (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00); - } - - /// Positive infinity. - static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7C00); - } - - /// Quiet NaN. - static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7FFF); - } - - /// Signaling NaN. - static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x7DFF); - } - - /// Smallest positive subnormal value. 
- static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW - { - return half_float::half(half_float::detail::binary, 0x0001); - } -}; - -#if HALF_ENABLE_CPP11_HASH -/// Hash function for half-precision floats. -/// This is only defined if C++11 `std::hash` is supported and enabled. -/// -/// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) -template <> -struct hash -{ - /// Type of function argument. - typedef half_float::half argument_type; - - /// Function return type. - typedef size_t result_type; - - /// Compute hash function. - /// \param arg half to hash - /// \return hash value - result_type operator()(argument_type arg) const - { - return hash()(arg.data_ & - -static_cast(arg.data_ != 0x8000)); - } -}; -#endif -} // namespace std - -namespace half_float { -/// \anchor compop -/// \name Comparison operators -/// \{ - -/// Comparison for equality. -/// \param x first operand -/// \param y second operand -/// \retval true if operands equal -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)); -} - -/// Comparison for inequality. -/// \param x first operand -/// \param y second operand -/// \retval true if operands not equal -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) -{ - return detail::compsignal(x.data_, y.data_) || - (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)); -} - -/// Comparison for less than. 
-/// \param x first operand -/// \param y second operand -/// \retval true if \a x less than \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// Comparison for greater than. -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater than \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// Comparison for less equal. -/// \param x first operand -/// \param y second operand -/// \retval true if \a x less equal \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// Comparison for greater equal. -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater equal \a y -/// \retval false else -/// \exception FE_INVALID if \a x or \a y is NaN -inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) -{ - return !detail::compsignal(x.data_, y.data_) && - ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); -} - -/// \} -/// \anchor arithmetics -/// \name Arithmetic operators -/// \{ - -/// Identity. 
-/// \param arg operand -/// \return unchanged operand -inline HALF_CONSTEXPR half operator+(half arg) { return arg; } - -/// Negation. -/// \param arg operand -/// \return negated operand -inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_ ^ 0x8000); } - -/// Addition. -/// This operation is exact to rounding for all rounding modes. -/// \param x left operand -/// \param y right operand -/// \return sum of half expressions -/// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator+(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) + - detail::half2float(y.data_))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; - bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absy != 0x7C00) ? x.data_ - : (sub && absx == 0x7C00) ? detail::invalid() : y.data_); - if(!absx) - return absy ? y - : half(detail::binary, - (half::round_style == std::round_toward_neg_infinity) - ? (x.data_ | y.data_) - : (x.data_ & y.data_)); - if(!absy) - return x; - unsigned int sign = ((sub && absy > absx) ? 
y.data_ : x.data_) & 0x8000; - if(absy > absx) - std::swap(absx, absy); - int exp = (absx >> 10) + (absx <= 0x3FF), d = exp - (absy >> 10) - (absy <= 0x3FF), - mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; - if(d < 13) - { - my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; - my = (my >> d) | ((my & ((1 << d) - 1)) != 0); - } - else - my = 1; - if(sub) - { - if(!(mx -= my)) - return half(detail::binary, - static_cast(half::round_style == std::round_toward_neg_infinity) - << 15); - for(; mx < 0x2000 && exp > 1; mx <<= 1, --exp) - ; - } - else - { - mx += my; - int i = mx >> 14; - if((exp += i) > 30) - return half(detail::binary, detail::overflow(sign)); - mx = (mx >> i) | (mx & i); - } - return half(detail::binary, - detail::rounded( - sign + ((exp - 1) << 10) + (mx >> 3), (mx >> 2) & 1, (mx & 0x3) != 0)); -#endif -} - -/// Subtraction. -/// This operation is exact to rounding for all rounding modes. -/// \param x left operand -/// \param y right operand -/// \return difference of half expressions -/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator-(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) - - detail::half2float(y.data_))); -#else - return x + -y; -#endif -} - -/// Multiplication. -/// This operation is exact to rounding for all rounding modes. 
-/// \param x left operand -/// \param y right operand -/// \return product of half expressions -/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator*(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) * - detail::half2float(y.data_))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; - unsigned int sign = (x.data_ ^ y.data_) & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) - ? detail::invalid() - : (sign | 0x7C00)); - if(!absx || !absy) - return half(detail::binary, sign); - for(; absx < 0x400; absx <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, --exp) - ; - detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * - static_cast((absy & 0x3FF) | 0x400); - int i = m >> 21, s = m & i; - exp += (absx >> 10) + (absy >> 10) + i; - if(exp > 29) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -11) - return half(detail::binary, detail::underflow(sign)); - return half( - detail::binary, - detail::fixed2half(m >> i, exp, sign, s)); -#endif -} - -/// Division. -/// This operation is exact to rounding for all rounding modes. 
-/// \param x left operand -/// \param y right operand -/// \return quotient of half expressions -/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is -/// signaling NaN -/// \exception FE_DIVBYZERO if dividing finite value by 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half operator/(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half( - detail::binary, - detail::float2half(detail::half2float(x.data_) / - detail::half2float(y.data_))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; - unsigned int sign = (x.data_ ^ y.data_) & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == absy) ? detail::invalid() - : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); - if(!absx) - return half(detail::binary, absy ? sign : detail::invalid()); - if(!absy) - return half(detail::binary, detail::pole(sign)); - for(; absx < 0x400; absx <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, ++exp) - ; - detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; - int i = mx < my; - exp += (absx >> 10) - (absy >> 10) - i; - if(exp > 29) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -11) - return half(detail::binary, detail::underflow(sign)); - mx <<= 12 + i; - my <<= 1; - return half(detail::binary, - detail::fixed2half( - mx / my, exp, sign, mx % my != 0)); -#endif -} - -/// \} -/// \anchor streaming -/// \name Input and output -/// \{ - -/// Output operator. -/// This uses the built-in functionality for streaming out floating-point numbers. 
-/// \param out output stream to write into -/// \param arg half expression to write -/// \return reference to output stream -template -std::basic_ostream& operator<<(std::basic_ostream& out, half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return out << detail::half2float(arg.data_); -#else - return out << detail::half2float(arg.data_); -#endif -} - -/// Input operator. -/// This uses the built-in functionality for streaming in floating-point numbers, specifically -/// double precision floating -/// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the -/// input string is first -/// rounded to double precision using the underlying platform's current floating-point rounding mode -/// before being rounded -/// to half-precision using the library's half-precision rounding mode. -/// \param in input stream to read from -/// \param arg half to read into -/// \return reference to input stream -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -template -std::basic_istream& operator>>(std::basic_istream& in, half& arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t f; -#else - double f; -#endif - if(in >> f) - arg.data_ = detail::float2half(f); - return in; -} - -/// \} -/// \anchor basic -/// \name Basic mathematical operations -/// \{ - -/// Absolute value. -/// **See also:** Documentation for -/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). -/// \param arg operand -/// \return absolute value of \a arg -inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_ & 0x7FFF); } - -/// Absolute value. -/// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). -/// \param arg operand -/// \return absolute value of \a arg -inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } - -/// Remainder of division. -/// **See also:** Documentation for -/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). 
-/// \param x first operand -/// \param y second operand -/// \return remainder of floating-point division. -/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN -inline half fmod(half x, half y) -{ - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == 0x7C00) ? detail::invalid() : x.data_); - if(!absy) - return half(detail::binary, detail::invalid()); - if(!absx) - return x; - if(absx == absy) - return half(detail::binary, sign); - return half(detail::binary, sign | detail::mod(absx, absy)); -} - -/// Remainder of division. -/// **See also:** Documentation for -/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). -/// \param x first operand -/// \param y second operand -/// \return remainder of floating-point division. -/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN -inline half remainder(half x, half y) -{ - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == 0x7C00) ? detail::invalid() : x.data_); - if(!absy) - return half(detail::binary, detail::invalid()); - if(absx == absy) - return half(detail::binary, sign); - return half(detail::binary, sign ^ detail::mod(absx, absy)); -} - -/// Remainder of division. -/// **See also:** Documentation for -/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). -/// \param x first operand -/// \param y second operand -/// \param quo address to store some bits of quotient at -/// \return remainder of floating-point division. 
-/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN -inline half remquo(half x, half y, int* quo) -{ - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absx == 0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); - if(!absy) - return half(detail::binary, detail::invalid()); - bool qsign = ((value ^ y.data_) & 0x8000) != 0; - int q = 1; - if(absx != absy) - value ^= detail::mod(absx, absy, &q); - return *quo = qsign ? -q : q, half(detail::binary, value); -} - -/// Fused multiply add. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). -/// \param x first operand -/// \param y second operand -/// \param z third operand -/// \return ( \a x * \a y ) + \a z rounded as one operation. -/// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet -/// NaN and no argument is a signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition -inline half fma(half x, half y, half z) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t fx = detail::half2float(x.data_), - fy = detail::half2float(y.data_), - fz = detail::half2float(z.data_); -#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA - return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); -#else - return half(detail::binary, detail::float2half(fx * fy + fz)); -#endif -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; - unsigned int sign = (x.data_ ^ y.data_) & 0x8000; - bool sub = ((sign ^ z.data_) & 0x8000) != 0; - if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) - return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) - ? 
half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) - : (absx == 0x7C00) ? half(detail::binary, - (!absy || (sub && absz == 0x7C00)) ? detail::invalid() - : (sign | 0x7C00)) - : (absy == 0x7C00) ? half(detail::binary, - (!absx || (sub && absz == 0x7C00)) - ? detail::invalid() - : (sign | 0x7C00)) - : z; - if(!absx || !absy) - return absz - ? z - : half(detail::binary, - (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign) - : (z.data_ & sign)); - for(; absx < 0x400; absx <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, --exp) - ; - detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * - static_cast((absy & 0x3FF) | 0x400); - int i = m >> 21; - exp += (absx >> 10) + (absy >> 10) + i; - m <<= 3 - i; - if(absz) - { - int expz = 0; - for(; absz < 0x400; absz <<= 1, --expz) - ; - expz += absz >> 10; - detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) << 13; - if(expz > exp || (expz == exp && mz > m)) - { - std::swap(m, mz); - std::swap(exp, expz); - if(sub) - sign = z.data_ & 0x8000; - } - int d = exp - expz; - mz = (d < 23) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; - if(sub) - { - m = m - mz; - if(!m) - return half( - detail::binary, - static_cast(half::round_style == std::round_toward_neg_infinity) - << 15); - for(; m < 0x800000; m <<= 1, --exp) - ; - } - else - { - m += mz; - i = m >> 24; - m = (m >> i) | (m & i); - exp += i; - } - } - if(exp > 30) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -10) - return half(detail::binary, detail::underflow(sign)); - return half(detail::binary, - detail::fixed2half(m, exp - 1, sign)); -#endif -} - -/// Maximum of half expressions. -/// **See also:** Documentation for -/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). 
-/// \param x first operand -/// \param y second operand -/// \return maximum of operands, ignoring quiet NaNs -/// \exception FE_INVALID if \a x or \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) -{ - return half(detail::binary, - (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) - ? detail::select(y.data_, x.data_) - : detail::select(x.data_, y.data_)); -} - -/// Minimum of half expressions. -/// **See also:** Documentation for -/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). -/// \param x first operand -/// \param y second operand -/// \return minimum of operands, ignoring quiet NaNs -/// \exception FE_INVALID if \a x or \a y is signaling NaN -inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) -{ - return half(detail::binary, - (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) - ? detail::select(y.data_, x.data_) - : detail::select(x.data_, y.data_)); -} - -/// Positive difference. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). -/// \param x first operand -/// \param y second operand -/// \return \a x - \a y or 0 if difference negative -/// \exception FE_... according to operator-(half,half) -inline half fdim(half x, half y) -{ - if(isnan(x) || isnan(y)) - return half(detail::binary, detail::signal(x.data_, y.data_)); - return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) - ? half(detail::binary, 0) - : (x - y); -} - -/// Get NaN value. -/// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). 
-/// \param arg string code -/// \return quiet NaN -inline half nanh(const char* arg) -{ - unsigned int value = 0x7FFF; - while(*arg) - value ^= static_cast(*arg++) & 0xFF; - return half(detail::binary, value); -} - -/// \} -/// \anchor exponential -/// \name Exponential functions -/// \{ - -/// Exponential function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). -/// \param arg function argument -/// \return e raised to \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half exp(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::exp(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) - : detail::signal(arg.data_)); - if(abs >= 0x4C80) - return half(detail::binary, - (arg.data_ & 0x8000) ? detail::underflow() - : detail::overflow()); - detail::uint32 m = detail::multiply64( - static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); - int e = (abs >> 10) + (abs <= 0x3FF), exp; - if(e < 14) - { - exp = 0; - m >>= 14 - e; - } - else - { - exp = m >> (45 - e); - m = (m << (e - 14)) & 0x7FFFFFFF; - } - return half(detail::binary, - detail::exp2_post( - detail::exp2(m, 26), exp, (arg.data_ & 0x8000) != 0)); -#endif -} - -/// Binary exponential. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). 
-/// \param arg function argument -/// \return 2 raised to \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half exp2(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::exp2(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) - : detail::signal(arg.data_)); - if(abs >= 0x4E40) - return half(detail::binary, - (arg.data_ & 0x8000) ? detail::underflow() - : detail::overflow()); - int e = (abs >> 10) + (abs <= 0x3FF), exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); - detail::uint32 m = detail::exp2((static_cast(exp) << (6 + e)) & 0x7FFFFFFF, 28); - exp >>= 25 - e; - if(m == 0x80000000) - { - if(arg.data_ & 0x8000) - exp = -exp; - else if(exp > 15) - return half(detail::binary, detail::overflow()); - return half(detail::binary, - detail::fixed2half(m, exp + 14)); - } - return half(detail::binary, - detail::exp2_post(m, exp, (arg.data_ & 0x8000) != 0)); -#endif -} - -/// Exponential minus one. -/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for -/// `std::round_to_nearest` -/// and in <1% of inputs for any other rounding mode. -/// -/// **See also:** Documentation for -/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). 
-/// \param arg function argument -/// \return e raised to \a arg and subtracted by 1 -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half expm1(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::expm1(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) : detail::signal(arg.data_)); - if(abs >= 0x4A00) - return half(detail::binary, - (arg.data_ & 0x8000) ? detail::rounded(0xBBFF, 1, 1) - : detail::overflow()); - detail::uint32 m = detail::multiply64( - static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); - int e = (abs >> 10) + (abs <= 0x3FF), exp; - if(e < 14) - { - exp = 0; - m >>= 14 - e; - } - else - { - exp = m >> (45 - e); - m = (m << (e - 14)) & 0x7FFFFFFF; - } - m = detail::exp2(m); - if(sign) - { - int s = 0; - if(m > 0x80000000) - { - ++exp; - m = detail::divide64(0x80000000, m, s); - } - m = 0x80000000 - - ((m >> exp) | ((m & ((static_cast(1) << exp) - 1)) != 0) | s); - exp = 0; - } - else - m -= (exp < 31) ? (0x80000000 >> exp) : 1; - for(exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) - ; - if(exp > 29) - return half(detail::binary, detail::overflow()); - return half(detail::binary, - detail::rounded( - sign + (exp << 10) + (m >> 21), (m >> 20) & 1, (m & 0xFFFFF) != 0)); -#endif -} - -/// Natural logarithm. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). 
-/// \param arg function argument -/// \return logarithm of \a arg to base e -/// \exception FE_INVALID for signaling NaN or negative argument -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::log(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(arg.data_ & 0x8000) - return half(detail::binary, - (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs >= 0x7C00) - return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - return half(detail::binary, - detail::log2_post( - detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, - exp, - 17)); -#endif -} - -/// Common logarithm. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). -/// \param arg function argument -/// \return logarithm of \a arg to base 10 -/// \exception FE_INVALID for signaling NaN or negative argument -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log10(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::log10(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(arg.data_ & 0x8000) - return half(detail::binary, - (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs >= 0x7C00) - return (abs == 0x7C00) ? 
arg : half(detail::binary, detail::signal(arg.data_)); - switch(abs) - { - case 0x4900: return half(detail::binary, 0x3C00); - case 0x5640: return half(detail::binary, 0x4000); - case 0x63D0: return half(detail::binary, 0x4200); - case 0x70E2: return half(detail::binary, 0x4400); - } - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - return half(detail::binary, - detail::log2_post( - detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, - exp, - 16)); -#endif -} - -/// Binary logarithm. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). -/// \param arg function argument -/// \return logarithm of \a arg to base 2 -/// \exception FE_INVALID for signaling NaN or negative argument -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log2(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::log2(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(arg.data_ & 0x8000) - return half(detail::binary, - (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs >= 0x7C00) - return (abs == 0x7C00) ? 
arg : half(detail::binary, detail::signal(arg.data_)); - if(abs == 0x3C00) - return half(detail::binary, 0); - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += (abs >> 10); - if(!(abs & 0x3FF)) - { - unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) << 6; - for(exp = 18; m < 0x400; m <<= 1, --exp) - ; - return half(detail::binary, value + (exp << 10) + m); - } - detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), - m = (((ilog << 27) + - (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, - 28) >> - 4)) ^ - sign) - - sign; - if(!m) - return half(detail::binary, 0); - for(exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) - ; - for(; m > 0xFFFFFFF; m >>= 1, ++exp) - s |= m & 1; - return half( - detail::binary, - detail::fixed2half(m, exp, sign & 0x8000, s)); -#endif -} - -/// Natural logarithm plus one. -/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for -/// `std::round_to_nearest` -/// and in ~1% of inputs for any other rounding mode. -/// -/// **See also:** Documentation for -/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). -/// \param arg function argument -/// \return logarithm of \a arg plus 1 to base e -/// \exception FE_INVALID for signaling NaN or argument <-1 -/// \exception FE_DIVBYZERO for -1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half log1p(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::log1p(detail::half2float(arg.data_)))); -#else - if(arg.data_ >= 0xBC00) - return half(detail::binary, - (arg.data_ == 0xBC00) - ? detail::pole(0x8000) - : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs || abs >= 0x7C00) - return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; - if(arg.data_ & 0x8000) - { - m = 0x40000000 - (m >> -exp); - for(exp = 0; m < 0x40000000; m <<= 1, --exp) - ; - } - else - { - if(exp < 0) - { - m = 0x40000000 + (m >> -exp); - exp = 0; - } - else - { - m += 0x40000000 >> exp; - int i = m >> 31; - m >>= i; - exp += i; - } - } - return half(detail::binary, - detail::log2_post(detail::log2(m), exp, 17)); -#endif -} - -/// \} -/// \anchor power -/// \name Power functions -/// \{ - -/// Square root. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). -/// \param arg function argument -/// \return square root of \a arg -/// \exception FE_INVALID for signaling NaN and negative arguments -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half sqrt(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::sqrt(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = 15; - if(!abs || arg.data_ >= 0x7C00) - return half(detail::binary, - (abs > 0x7C00) ? detail::signal(arg.data_) - : (arg.data_ > 0x8000) ? detail::invalid() : arg.data_); - for(; abs < 0x400; abs <<= 1, --exp) - ; - detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, - m = detail::sqrt<20>(r, exp += abs >> 10); - return half( - detail::binary, - detail::rounded((exp << 10) + (m & 0x3FF), r > m, r != 0)); -#endif -} - -/// Cubic root. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). 
-/// \param arg function argument -/// \return cubic root of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half cbrt(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::cbrt(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = -15; - if(!abs || abs == 0x3C00 || abs >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --exp) - ; - detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, - m = (((ilog << 27) + - (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, - 24) >> - 4)) ^ - sign) - - sign; - for(exp = 2; m < 0x80000000; m <<= 1, --exp) - ; - m = detail::multiply64(m, 0xAAAAAAAB); - int i = m >> 31, s; - exp += i; - m <<= 1 - i; - if(exp < 0) - { - f = m >> -exp; - exp = 0; - } - else - { - f = (m << exp) & 0x7FFFFFFF; - exp = m >> (31 - exp); - } - m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); - if(sign) - { - if(m > 0x80000000) - { - m = detail::divide64(0x80000000, m, s); - ++exp; - } - exp = -exp; - } - return half(detail::binary, - (half::round_style == std::round_to_nearest) - ? detail::fixed2half( - m, exp + 14, arg.data_ & 0x8000) - : detail::fixed2half( - (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); -#endif -} - -/// Hypotenuse function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). 
-/// \param x first argument -/// \param y second argument -/// \return square root of sum of squares without internal over- or underflows -/// \exception FE_INVALID if \a x or \a y is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root -inline half hypot(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t fx = detail::half2float(x.data_), - fy = detail::half2float(y.data_); -#if HALF_ENABLE_CPP11_CMATH - return half(detail::binary, detail::float2half(std::hypot(fx, fy))); -#else - return half(detail::binary, - detail::float2half(std::sqrt(fx * fx + fy * fy))); -#endif -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx == 0x7C00) ? detail::select(0x7C00, y.data_) - : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) - : detail::signal(x.data_, y.data_)); - if(!absx) - return half(detail::binary, absy ? detail::check_underflow(absy) : 0); - if(!absy) - return half(detail::binary, detail::check_underflow(absx)); - if(absy > absx) - std::swap(absx, absy); - for(; absx < 0x400; absx <<= 1, --expx) - ; - for(; absy < 0x400; absy <<= 1, --expy) - ; - detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; - mx *= mx; - my *= my; - int ix = mx >> 21, iy = my >> 21; - expx = 2 * (expx + (absx >> 10)) - 15 + ix; - expy = 2 * (expy + (absy >> 10)) - 15 + iy; - mx <<= 10 - ix; - my <<= 10 - iy; - int d = expx - expy; - my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; - return half(detail::binary, detail::hypot_post(mx + my, expx)); -#endif -} - -/// Hypotenuse function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). 
-/// \param x first argument -/// \param y second argument -/// \param z third argument -/// \return square root of sum of squares without internal over- or underflows -/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root -inline half hypot(half x, half y, half z) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t fx = detail::half2float(x.data_), - fy = detail::half2float(y.data_), - fz = detail::half2float(z.data_); - return half(detail::binary, - detail::float2half(std::sqrt(fx * fx + fy * fy + fz * fz))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, - expy = 0, expz = 0; - if(!absx) - return hypot(y, z); - if(!absy) - return hypot(x, z); - if(!absz) - return hypot(x, y); - if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) - return half(detail::binary, - (absx == 0x7C00) - ? detail::select(0x7C00, detail::select(y.data_, z.data_)) - : (absy == 0x7C00) - ? detail::select(0x7C00, detail::select(x.data_, z.data_)) - : (absz == 0x7C00) - ? detail::select(0x7C00, detail::select(x.data_, y.data_)) - : detail::signal(x.data_, y.data_, z.data_)); - if(absz > absy) - std::swap(absy, absz); - if(absy > absx) - std::swap(absx, absy); - if(absz > absy) - std::swap(absy, absz); - for(; absx < 0x400; absx <<= 1, --expx) - ; - for(; absy < 0x400; absy <<= 1, --expy) - ; - for(; absz < 0x400; absz <<= 1, --expz) - ; - detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400, - mz = (absz & 0x3FF) | 0x400; - mx *= mx; - my *= my; - mz *= mz; - int ix = mx >> 21, iy = my >> 21, iz = mz >> 21; - expx = 2 * (expx + (absx >> 10)) - 15 + ix; - expy = 2 * (expy + (absy >> 10)) - 15 + iy; - expz = 2 * (expz + (absz >> 10)) - 15 + iz; - mx <<= 10 - ix; - my <<= 10 - iy; - mz <<= 10 - iz; - int d = expy - expz; - mz = (d < 30) ? 
((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; - my += mz; - if(my & 0x80000000) - { - my = (my >> 1) | (my & 1); - if(++expy > expx) - { - std::swap(mx, my); - std::swap(expx, expy); - } - } - d = expx - expy; - my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; - return half(detail::binary, detail::hypot_post(mx + my, expx)); -#endif -} - -/// Power function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in -/// ~0.00025% of inputs. -/// -/// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow). -/// \param x base -/// \param y exponent -/// \return \a x raised to \a y -/// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite an negative and \a y -/// is finite and not integral -/// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half pow(half x, half y) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::pow(detail::half2float(x.data_), - detail::half2float(y.data_)))); -#else - int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15; - if(!absy || x.data_ == 0x3C00) - return half(detail::binary, - detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_)); - bool is_int = absy >= 0x6400 || (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1))); - unsigned int sign = - x.data_ & - (static_cast((absy < 0x6800) && is_int && ((absy >> (25 - (absy >> 10))) & 1)) - << 15); - if(absx >= 0x7C00 || absy >= 0x7C00) - return half(detail::binary, - (absx > 0x7C00 || absy > 0x7C00) - ? detail::signal(x.data_, y.data_) - : (absy == 0x7C00) - ? ((absx == 0x3C00) - ? 0x3C00 - : (!absx && y.data_ == 0xFC00) - ? 
detail::pole() - : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00)))) - : (sign | (0x7C00 & ((y.data_ >> 15) - 1U)))); - if(!absx) - return half(detail::binary, (y.data_ & 0x8000) ? detail::pole(sign) : sign); - if((x.data_ & 0x8000) && !is_int) - return half(detail::binary, detail::invalid()); - if(x.data_ == 0xBC00) - return half(detail::binary, sign | 0x3C00); - if(y.data_ == 0x3800) - return sqrt(x); - if(y.data_ == 0x3C00) - return half(detail::binary, detail::check_underflow(x.data_)); - if(y.data_ == 0x4000) - return x * x; - for(; absx < 0x400; absx <<= 1, --exp) - ; - detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, - m = (((ilog << 27) + - ((detail::log2(static_cast((absx & 0x3FF) | 0x400) << 20) + - 8) >> - 4)) ^ - msign) - - msign; - for(exp = -11; m < 0x80000000; m <<= 1, --exp) - ; - for(; absy < 0x400; absy <<= 1, --exp) - ; - m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) << 21); - int i = m >> 31; - exp += (absy >> 10) + i; - m <<= 1 - i; - if(exp < 0) - { - f = m >> -exp; - exp = 0; - } - else - { - f = (m << exp) & 0x7FFFFFFF; - exp = m >> (31 - exp); - } - return half(detail::binary, - detail::exp2_post( - detail::exp2(f), exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); -#endif -} - -/// \} -/// \anchor trigonometric -/// \name Trigonometric functions -/// \{ - -/// Compute sine and cosine simultaneously. -/// This returns the same results as sin() and cos() but is faster than calling each function -/// individually. -/// -/// This function is exact to rounding for all rounding modes. 
-/// \param arg function argument -/// \param sin variable to take sine of \a arg -/// \param cos variable to take cosine of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline void sincos(half arg, half* sin, half* cos) -{ -#ifdef HALF_ARITHMETIC_TYPE - detail::internal_t f = detail::half2float(arg.data_); - *sin = half(detail::binary, detail::float2half(std::sin(f))); - *cos = half(detail::binary, detail::float2half(std::cos(f))); -#else - int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; - if(abs >= 0x7C00) - *sin = *cos = - half(detail::binary, (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - else if(!abs) - { - *sin = arg; - *cos = half(detail::binary, 0x3C00); - } - else if(abs < 0x2500) - { - *sin = half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); - } - else - { - if(half::round_style != std::round_to_nearest) - { - switch(abs) - { - case 0x48B7: - *sin = half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); - *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); - return; - case 0x598C: - *sin = half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); - return; - case 0x6A64: - *sin = half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); - return; - case 0x6D8C: - *sin = half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); - *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); - return; - } - } - std::pair sc = - detail::sincos(detail::angle_arg(abs, k), 28); - switch(k & 3) - { - case 1: sc = std::make_pair(sc.second, -sc.first); break; - case 2: sc = std::make_pair(-sc.first, -sc.second); break; - case 3: sc = 
std::make_pair(-sc.second, sc.first); break; - } - *sin = half(detail::binary, - detail::fixed2half( - (sc.first ^ -static_cast(sign)) + sign)); - *cos = half(detail::binary, - detail::fixed2half(sc.second)); - } -#endif -} - -/// Sine function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). -/// \param arg function argument -/// \return sine value of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half sin(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::sin(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, k; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2900) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - if(half::round_style != std::round_to_nearest) - switch(abs) - { - case 0x48B7: - return half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); - case 0x6A64: - return half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); - case 0x6D8C: - return half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); - } - std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); - detail::uint32 sign = -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); - return half(detail::binary, - detail::fixed2half( - (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); -#endif -} - -/// Cosine function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). 
-/// \param arg function argument -/// \return cosine value of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half cos(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::cos(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, k; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2500) - return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); - if(half::round_style != std::round_to_nearest && abs == 0x598C) - return half(detail::binary, detail::rounded(0x80FC, 1, 1)); - std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); - detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); - return half(detail::binary, - detail::fixed2half( - (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); -#endif -} - -/// Tangent function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). -/// \param arg function argument -/// \return tangent value of \a arg -/// \exception FE_INVALID for signaling NaN or infinity -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half tan(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::tan(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = 13, k; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2700) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - if(half::round_style != std::round_to_nearest) - switch(abs) - { - case 0x658C: - return half( - detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x07E6, 1, 1)); - case 0x7330: - return half( - detail::binary, - detail::rounded((~arg.data_ & 0x8000) | 0x4B62, 1, 1)); - } - std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); - if(k & 1) - sc = std::make_pair(-sc.second, sc.first); - detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); - detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; - for(; my < 0x80000000; my <<= 1, --exp) - ; - for(; mx < 0x80000000; mx <<= 1, ++exp) - ; - return half( - detail::binary, - detail::tangent_post(my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); -#endif -} - -/// Arc sine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). -/// \param arg function argument -/// \return arc sine value of \a arg -/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half asin(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::asin(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(!abs) - return arg; - if(abs >= 0x3C00) - return half(detail::binary, - (abs > 0x7C00) - ? detail::signal(arg.data_) - : (abs > 0x3C00) - ? 
detail::invalid() - : detail::rounded(sign | 0x3E48, 0, 1)); - if(abs < 0x2900) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) - return half(detail::binary, detail::rounded(arg.data_ + 1, 1, 1)); - std::pair sc = detail::atan2_args(abs); - detail::uint32 m = - detail::atan2(sc.first, sc.second, (half::round_style == std::round_to_nearest) ? 27 : 26); - return half(detail::binary, - detail::fixed2half(m, 14, sign)); -#endif -} - -/// Arc cosine function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). -/// \param arg function argument -/// \return arc cosine value of \a arg -/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half acos(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::acos(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; - if(!abs) - return half(detail::binary, detail::rounded(0x3E48, 0, 1)); - if(abs >= 0x3C00) - return half(detail::binary, - (abs > 0x7C00) - ? detail::signal(arg.data_) - : (abs > 0x3C00) - ? detail::invalid() - : sign ? detail::rounded(0x4248, 0, 1) : 0); - std::pair cs = detail::atan2_args(abs); - detail::uint32 m = detail::atan2(cs.second, cs.first, 28); - return half(detail::binary, - detail::fixed2half( - sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); -#endif -} - -/// Arc tangent function. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). 
-/// \param arg function argument -/// \return arc tangent value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half atan(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::atan(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs == 0x7C00) ? detail::rounded(sign | 0x3E48, 0, 1) - : detail::signal(arg.data_)); - if(abs <= 0x2700) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - int exp = (abs >> 10) + (abs <= 0x3FF); - detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); - detail::uint32 m = (exp > 15) - ? detail::atan2(my << 19, - 0x20000000 >> (exp - 15), - (half::round_style == std::round_to_nearest) ? 26 : 24) - : detail::atan2(my << (exp + 4), - 0x20000000, - (half::round_style == std::round_to_nearest) ? 30 : 28); - return half(detail::binary, - detail::fixed2half(m, 14, sign)); -#endif -} - -/// Arc tangent function. -/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for -/// `std::round_to_nearest`, -/// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding -/// mode. -/// -/// **See also:** Documentation for -/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). 
-/// \param y numerator -/// \param x denominator -/// \return arc tangent value -/// \exception FE_INVALID if \a x or \a y is signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half atan2(half y, half x) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::atan2(detail::half2float(y.data_), - detail::half2float(x.data_)))); -#else - unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, - signy = y.data_ & 0x8000; - if(absx >= 0x7C00 || absy >= 0x7C00) - { - if(absx > 0x7C00 || absy > 0x7C00) - return half(detail::binary, detail::signal(x.data_, y.data_)); - if(absy == 0x7C00) - return half(detail::binary, - (absx < 0x7C00) - ? detail::rounded(signy | 0x3E48, 0, 1) - : signx - ? detail::rounded(signy | 0x40B6, 0, 1) - : detail::rounded(signy | 0x3A48, 0, 1)); - return (x.data_ == 0x7C00) - ? half(detail::binary, signy) - : half(detail::binary, - detail::rounded(signy | 0x4248, 0, 1)); - } - if(!absy) - return signx ? half(detail::binary, - detail::rounded(signy | 0x4248, 0, 1)) - : y; - if(!absx) - return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); - int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); - if(d > (signx ? 18 : 12)) - return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); - if(signx && d < -11) - return half(detail::binary, detail::rounded(signy | 0x4248, 0, 1)); - if(!signx && d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) - { - for(; absy < 0x400; absy <<= 1, --d) - ; - detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, my = ((absy << 1) & 0x7FF) | 0x800; - int i = my < mx; - d -= i; - if(d < -25) - return half(detail::binary, detail::underflow(signy)); - my <<= 11 + i; - return half(detail::binary, - detail::fixed2half( - my / mx, d + 14, signy, my % mx != 0)); - } - detail::uint32 m = detail::atan2( - ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? 
d : (d > 0) ? 0 : -1)), - ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d : (d < 0) ? 0 : 1))); - return half(detail::binary, - detail::fixed2half( - signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); -#endif -} - -/// \} -/// \anchor hyperbolic -/// \name Hyperbolic functions -/// \{ - -/// Hyperbolic sine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). -/// \param arg function argument -/// \return hyperbolic sine value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half sinh(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::sinh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp; - if(!abs || abs >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - if(abs <= 0x2900) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - std::pair mm = - detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); - detail::uint32 m = mm.first - mm.second; - for(exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) - ; - unsigned int sign = arg.data_ & 0x8000; - if(exp > 29) - return half(detail::binary, detail::overflow(sign)); - return half(detail::binary, - detail::fixed2half(m, exp, sign)); -#endif -} - -/// Hyperbolic cosine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). 
-/// \param arg function argument -/// \return hyperbolic cosine value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half cosh(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::cosh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x7C00) - return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); - std::pair mm = - detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); - detail::uint32 m = mm.first + mm.second, i = (~m & 0xFFFFFFFF) >> 31; - m = (m >> i) | (m & i) | 0x80000000; - if((exp += 13 + i) > 29) - return half(detail::binary, detail::overflow()); - return half(detail::binary, - detail::fixed2half(m, exp)); -#endif -} - -/// Hyperbolic tangent. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). -/// \param arg function argument -/// \return hyperbolic tangent value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half tanh(half arg) -{ -#ifdef HALF_ARITHMETIC_TYPE - return half(detail::binary, - detail::float2half( - std::tanh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp; - if(!abs) - return arg; - if(abs >= 0x7C00) - return half(detail::binary, - (abs > 0x7C00) ? 
detail::signal(arg.data_) : (arg.data_ - 0x4000)); - if(abs >= 0x4500) - return half(detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); - if(abs < 0x2700) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - if(half::round_style != std::round_to_nearest && abs == 0x2D3F) - return half(detail::binary, detail::rounded(arg.data_ - 3, 0, 1)); - std::pair mm = detail::hyperbolic_args(abs, exp, 27); - detail::uint32 my = mm.first - mm.second - (half::round_style != std::round_to_nearest), - mx = mm.first + mm.second, i = (~mx & 0xFFFFFFFF) >> 31; - for(exp = 13; my < 0x80000000; my <<= 1, --exp) - ; - mx = (mx >> i) | 0x80000000; - return half(detail::binary, - detail::tangent_post(my, mx, exp - i, arg.data_ & 0x8000)); -#endif -} - -/// Hyperbolic area sine. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). -/// \param arg function argument -/// \return area sine value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half asinh(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::asinh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(!abs || abs >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - if(abs <= 0x2900) - return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); - if(half::round_style != std::round_to_nearest) - switch(abs) - { - case 0x32D4: - return half(detail::binary, - detail::rounded(arg.data_ - 13, 1, 1)); - case 0x3B5B: - return half(detail::binary, - detail::rounded(arg.data_ - 197, 1, 1)); - } - return half(detail::binary, detail::area(arg.data_)); -#endif -} - -/// Hyperbolic area cosine. 
-/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). -/// \param arg function argument -/// \return area cosine value of \a arg -/// \exception FE_INVALID for signaling NaN or arguments <1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half acosh(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::acosh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if((arg.data_ & 0x8000) || abs < 0x3C00) - return half(detail::binary, - (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); - if(abs == 0x3C00) - return half(detail::binary, 0); - if(arg.data_ >= 0x7C00) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - return half(detail::binary, detail::area(arg.data_)); -#endif -} - -/// Hyperbolic area tangent. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). -/// \param arg function argument -/// \return area tangent value of \a arg -/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 -/// \exception FE_DIVBYZERO for +/-1 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half atanh(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::atanh(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF, exp = 0; - if(!abs) - return arg; - if(abs >= 0x3C00) - return half(detail::binary, - (abs == 0x3C00) - ? detail::pole(arg.data_ & 0x8000) - : (abs <= 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); - if(abs < 0x2700) - return half(detail::binary, detail::rounded(arg.data_, 0, 1)); - detail::uint32 m = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) - << ((abs >> 10) + (abs <= 0x3FF) + 6), - my = 0x80000000 + m, mx = 0x80000000 - m; - for(; mx < 0x80000000; mx <<= 1, ++exp) - ; - int i = my >= mx, s; - return half(detail::binary, - detail::log2_post( - detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, - exp + i - 1, - 16, - arg.data_ & 0x8000)); -#endif -} - -/// \} -/// \anchor special -/// \name Error and gamma functions -/// \{ - -/// Error function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% -/// of inputs. -/// -/// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). -/// \param arg function argument -/// \return error function value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half erf(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::erf(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF; - if(!abs || abs >= 0x7C00) - return (abs >= 0x7C00) - ? half(detail::binary, - (abs == 0x7C00) ? (arg.data_ - 0x4000) : detail::signal(arg.data_)) - : arg; - if(abs >= 0x4200) - return half(detail::binary, - detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); - return half(detail::binary, detail::erf(arg.data_)); -#endif -} - -/// Complementary error function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% -/// of inputs. -/// -/// **See also:** Documentation for -/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). 
-/// \param arg function argument -/// \return 1 minus error function value of \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half erfc(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::erfc(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(abs >= 0x7C00) - return (abs >= 0x7C00) - ? half(detail::binary, (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_)) - : arg; - if(!abs) - return half(detail::binary, 0x3C00); - if(abs >= 0x4400) - return half( - detail::binary, - detail::rounded((sign >> 1) - (sign >> 15), sign >> 15, 1)); - return half(detail::binary, detail::erf(arg.data_)); -#endif -} - -/// Natural logarithm of gamma function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in -/// ~0.025% of inputs. -/// -/// **See also:** Documentation for -/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). -/// \param arg function argument -/// \return natural logarith of gamma function for \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_DIVBYZERO for 0 or negative integer arguments -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half lgamma(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::lgamma(detail::half2float(arg.data_)))); -#else - int abs = arg.data_ & 0x7FFF; - if(abs >= 0x7C00) - return half(detail::binary, (abs == 0x7C00) ? 
0x7C00 : detail::signal(arg.data_)); - if(!abs || arg.data_ >= 0xE400 || - (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) - return half(detail::binary, detail::pole()); - if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) - return half(detail::binary, 0); - return half(detail::binary, detail::gamma(arg.data_)); -#endif -} - -/// Gamma function. -/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in -/// <0.25% of inputs. -/// -/// **See also:** Documentation for -/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). -/// \param arg function argument -/// \return gamma function value of \a arg -/// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments -/// \exception FE_DIVBYZERO for 0 -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half tgamma(half arg) -{ -#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH - return half(detail::binary, - detail::float2half( - std::tgamma(detail::half2float(arg.data_)))); -#else - unsigned int abs = arg.data_ & 0x7FFF; - if(!abs) - return half(detail::binary, detail::pole(arg.data_)); - if(abs >= 0x7C00) - return (arg.data_ == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); - if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) - return half(detail::binary, detail::invalid()); - if(arg.data_ >= 0xCA80) - return half( - detail::binary, - detail::underflow((1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); - if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) - return half(detail::binary, detail::overflow()); - if(arg.data_ == 0x3C00) - return arg; - return half(detail::binary, detail::gamma(arg.data_)); -#endif -} - -/// \} -/// \anchor rounding -/// \name Rounding -/// \{ - -/// Nearest integer not less than half value. 
-/// **See also:** Documentation for -/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). -/// \param arg half to round -/// \return nearest integer not less than \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half ceil(half arg) -{ - return half(detail::binary, - detail::integral(arg.data_)); -} - -/// Nearest integer not greater than half value. -/// **See also:** Documentation for -/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). -/// \param arg half to round -/// \return nearest integer not greater than \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half floor(half arg) -{ - return half(detail::binary, - detail::integral(arg.data_)); -} - -/// Nearest integer not greater in magnitude than half value. -/// **See also:** Documentation for -/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). -/// \param arg half to round -/// \return nearest integer not greater in magnitude than \a arg -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half trunc(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} - -/// Nearest integer. -/// **See also:** Documentation for -/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). -/// \param arg half to round -/// \return nearest integer, rounded away from zero in half-way cases -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half round(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} - -/// Nearest integer. -/// **See also:** Documentation for -/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). 
-/// \param arg half to round -/// \return nearest integer, rounded away from zero in half-way cases -/// \exception FE_INVALID if value is not representable as `long` -inline long lround(half arg) -{ - return detail::half2int(arg.data_); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID for signaling NaN -/// \exception FE_INEXACT if value had to be rounded -inline half rint(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID if value is not representable as `long` -/// \exception FE_INEXACT if value had to be rounded -inline long lrint(half arg) -{ - return detail::half2int(arg.data_); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID for signaling NaN -inline half nearbyint(half arg) -{ - return half(detail::binary, detail::integral(arg.data_)); -} -#if HALF_ENABLE_CPP11_LONG_LONG -/// Nearest integer. -/// **See also:** Documentation for -/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). 
-/// \param arg half to round -/// \return nearest integer, rounded away from zero in half-way cases -/// \exception FE_INVALID if value is not representable as `long long` -inline long long llround(half arg) -{ - return detail::half2int(arg.data_); -} - -/// Nearest integer using half's internal rounding mode. -/// **See also:** Documentation for -/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). -/// \param arg half expression to round -/// \return nearest integer using default rounding mode -/// \exception FE_INVALID if value is not representable as `long long` -/// \exception FE_INEXACT if value had to be rounded -inline long long llrint(half arg) -{ - return detail::half2int(arg.data_); -} -#endif - -/// \} -/// \anchor float -/// \name Floating point manipulation -/// \{ - -/// Decompress floating-point number. -/// **See also:** Documentation for -/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). -/// \param arg number to decompress -/// \param exp address to store exponent at -/// \return significant in range [0.5, 1) -/// \exception FE_INVALID for signaling NaN -inline half frexp(half arg, int* exp) -{ - *exp = 0; - unsigned int abs = arg.data_ & 0x7FFF; - if(abs >= 0x7C00 || !abs) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --*exp) - ; - *exp += (abs >> 10) - 14; - return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF)); -} - -/// Multiply by power of two. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). 
-/// \param arg number to modify -/// \param exp power of two to multiply with -/// \return \a arg multplied by 2 raised to \a exp -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half scalbln(half arg, long exp) -{ - unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; - if(abs >= 0x7C00 || !abs) - return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; - for(; abs < 0x400; abs <<= 1, --exp) - ; - exp += abs >> 10; - if(exp > 30) - return half(detail::binary, detail::overflow(sign)); - else if(exp < -10) - return half(detail::binary, detail::underflow(sign)); - else if(exp > 0) - return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF)); - unsigned int m = (abs & 0x3FF) | 0x400; - return half(detail::binary, - detail::rounded( - sign | (m >> (1 - exp)), (m >> -exp) & 1, (m & ((1 << -exp) - 1)) != 0)); -} - -/// Multiply by power of two. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). -/// \param arg number to modify -/// \param exp power of two to multiply with -/// \return \a arg multplied by 2 raised to \a exp -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } - -/// Multiply by power of two. -/// This function is exact to rounding for all rounding modes. -/// -/// **See also:** Documentation for -/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). 
-/// \param arg number to modify -/// \param exp power of two to multiply with -/// \return \a arg multplied by 2 raised to \a exp -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } - -/// Extract integer and fractional parts. -/// **See also:** Documentation for -/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). -/// \param arg number to decompress -/// \param iptr address to store integer part at -/// \return fractional part -/// \exception FE_INVALID for signaling NaN -inline half modf(half arg, half* iptr) -{ - unsigned int abs = arg.data_ & 0x7FFF; - if(abs > 0x7C00) - { - arg = half(detail::binary, detail::signal(arg.data_)); - return *iptr = arg, arg; - } - if(abs >= 0x6400) - return *iptr = arg, half(detail::binary, arg.data_ & 0x8000); - if(abs < 0x3C00) - return iptr->data_ = arg.data_ & 0x8000, arg; - unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, m = arg.data_ & mask; - iptr->data_ = arg.data_ & ~mask; - if(!m) - return half(detail::binary, arg.data_ & 0x8000); - for(; m < 0x400; m <<= 1, --exp) - ; - return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF)); -} - -/// Extract exponent. -/// **See also:** Documentation for -/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). -/// \param arg number to query -/// \return floating-point exponent -/// \retval FP_ILOGB0 for zero -/// \retval FP_ILOGBNAN for NaN -/// \retval INT_MAX for infinity -/// \exception FE_INVALID for 0 or infinite values -inline int ilogb(half arg) -{ - int abs = arg.data_ & 0x7FFF, exp; - if(!abs || abs >= 0x7C00) - { - detail::raise(FE_INVALID); - return !abs ? FP_ILOGB0 : (abs == 0x7C00) ? INT_MAX : FP_ILOGBNAN; - } - for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) - ; - return exp; -} - -/// Extract exponent. 
-/// **See also:** Documentation for -/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). -/// \param arg number to query -/// \return floating-point exponent -/// \exception FE_INVALID for signaling NaN -/// \exception FE_DIVBYZERO for 0 -inline half logb(half arg) -{ - int abs = arg.data_ & 0x7FFF, exp; - if(!abs) - return half(detail::binary, detail::pole(0x8000)); - if(abs >= 0x7C00) - return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); - for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) - ; - unsigned int value = static_cast(exp < 0) << 15; - if(exp) - { - unsigned int m = std::abs(exp) << 6; - for(exp = 18; m < 0x400; m <<= 1, --exp) - ; - value |= (exp << 10) + m; - } - return half(detail::binary, value); -} - -/// Next representable value. -/// **See also:** Documentation for -/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). -/// \param from value to compute next representable value for -/// \param to direction towards which to compute next value -/// \return next representable value after \a from in direction towards \a to -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW for infinite result from finite argument -/// \exception FE_UNDERFLOW for subnormal result -inline half nextafter(half from, half to) -{ - int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; - if(fabs > 0x7C00 || tabs > 0x7C00) - return half(detail::binary, detail::signal(from.data_, to.data_)); - if(from.data_ == to.data_ || !(fabs | tabs)) - return to; - if(!fabs) - { - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); - return half(detail::binary, (to.data_ & 0x8000) + 1); - } - unsigned int out = - from.data_ + - (((from.data_ >> 15) ^ - static_cast((from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < - (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) - << 1) - - 1; - detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); - 
detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7C00) < 0x400); - return half(detail::binary, out); -} - -/// Next representable value. -/// **See also:** Documentation for -/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). -/// \param from value to compute next representable value for -/// \param to direction towards which to compute next value -/// \return next representable value after \a from in direction towards \a to -/// \exception FE_INVALID for signaling NaN -/// \exception FE_OVERFLOW for infinite result from finite argument -/// \exception FE_UNDERFLOW for subnormal result -inline half nexttoward(half from, long double to) -{ - int fabs = from.data_ & 0x7FFF; - if(fabs > 0x7C00) - return half(detail::binary, detail::signal(from.data_)); - long double lfrom = static_cast(from); - if(detail::builtin_isnan(to) || lfrom == to) - return half(static_cast(to)); - if(!fabs) - { - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); - return half(detail::binary, (static_cast(detail::builtin_signbit(to)) << 15) + 1); - } - unsigned int out = - from.data_ + (((from.data_ >> 15) ^ static_cast(lfrom < to)) << 1) - 1; - detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); - detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7FFF) < 0x400); - return half(detail::binary, out); -} - -/// Take sign. -/// **See also:** Documentation for -/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign). -/// \param x value to change sign for -/// \param y value to take sign from -/// \return value equal to \a x in magnitude and to \a y in sign -inline HALF_CONSTEXPR half copysign(half x, half y) -{ - return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); -} - -/// \} -/// \anchor classification -/// \name Floating point classification -/// \{ - -/// Classify floating-point value. 
-/// **See also:** Documentation for -/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify). -/// \param arg number to classify -/// \retval FP_ZERO for positive and negative zero -/// \retval FP_SUBNORMAL for subnormal numbers -/// \retval FP_INFINITY for positive and negative infinity -/// \retval FP_NAN for NaNs -/// \retval FP_NORMAL for all other (normal) values -inline HALF_CONSTEXPR int fpclassify(half arg) -{ - return !(arg.data_ & 0x7FFF) - ? FP_ZERO - : ((arg.data_ & 0x7FFF) < 0x400) - ? FP_SUBNORMAL - : ((arg.data_ & 0x7FFF) < 0x7C00) - ? FP_NORMAL - : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN; -} - -/// Check if finite number. -/// **See also:** Documentation for -/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite). -/// \param arg number to check -/// \retval true if neither infinity nor NaN -/// \retval false else -inline HALF_CONSTEXPR bool isfinite(half arg) { return (arg.data_ & 0x7C00) != 0x7C00; } - -/// Check for infinity. -/// **See also:** Documentation for -/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf). -/// \param arg number to check -/// \retval true for positive or negative infinity -/// \retval false else -inline HALF_CONSTEXPR bool isinf(half arg) { return (arg.data_ & 0x7FFF) == 0x7C00; } - -/// Check for NaN. -/// **See also:** Documentation for -/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan). -/// \param arg number to check -/// \retval true for NaNs -/// \retval false else -inline HALF_CONSTEXPR bool isnan(half arg) { return (arg.data_ & 0x7FFF) > 0x7C00; } - -/// Check if normal number. -/// **See also:** Documentation for -/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). 
-/// \param arg number to check -/// \retval true if normal number -/// \retval false if either subnormal, zero, infinity or NaN -inline HALF_CONSTEXPR bool isnormal(half arg) -{ - return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00); -} - -/// Check sign. -/// **See also:** Documentation for -/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). -/// \param arg number to check -/// \retval true for negative number -/// \retval false for positive number -inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_ & 0x8000) != 0; } - -/// \} -/// \anchor compfunc -/// \name Comparison -/// \{ - -/// Quiet comparison for greater than. -/// **See also:** Documentation for -/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater than \a y -/// \retval false else -inline HALF_CONSTEXPR bool isgreater(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comparison for greater equal. -/// **See also:** Documentation for -/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x greater equal \a y -/// \retval false else -inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comparison for less than. -/// **See also:** Documentation for -/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). 
-/// \param x first operand -/// \param y second operand -/// \retval true if \a x less than \a y -/// \retval false else -inline HALF_CONSTEXPR bool isless(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comparison for less equal. -/// **See also:** Documentation for -/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). -/// \param x first operand -/// \param y second operand -/// \retval true if \a x less equal \a y -/// \retval false else -inline HALF_CONSTEXPR bool islessequal(half x, half y) -{ - return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= - ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && - !isnan(x) && !isnan(y); -} - -/// Quiet comarison for less or greater. -/// **See also:** Documentation for -/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). -/// \param x first operand -/// \param y second operand -/// \retval true if either less or greater -/// \retval false else -inline HALF_CONSTEXPR bool islessgreater(half x, half y) -{ - return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && !isnan(y); -} - -/// Quiet check if unordered. -/// **See also:** Documentation for -/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). -/// \param x first operand -/// \param y second operand -/// \retval true if unordered (one or two NaN operands) -/// \retval false else -inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } - -/// \} -/// \anchor casting -/// \name Casting -/// \{ - -/// Cast to or from half-precision floating-point number. -/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. 
The values -/// are converted -/// directly using the default rounding mode, without any roundtrip over `float` that a -/// `static_cast` would otherwise do. -/// -/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any -/// of the two types -/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) -/// results in a compiler -/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. -/// \tparam T destination type (half or built-in arithmetic type) -/// \tparam U source type (half or built-in arithmetic type) -/// \param arg value to cast -/// \return \a arg converted to destination type -/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -template -T half_cast(U arg) -{ - return detail::half_caster::cast(arg); -} - -/// Cast to or from half-precision floating-point number. -/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values -/// are converted -/// directly using the specified rounding mode, without any roundtrip over `float` that a -/// `static_cast` would otherwise do. -/// -/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any -/// of the two types -/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) -/// results in a compiler -/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. -/// \tparam T destination type (half or built-in arithmetic type) -/// \tparam R rounding mode to use. 
-/// \tparam U source type (half or built-in arithmetic type) -/// \param arg value to cast -/// \return \a arg converted to destination type -/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T -/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding -template -T half_cast(U arg) -{ - return detail::half_caster::cast(arg); -} -/// \} - -/// \} -/// \anchor errors -/// \name Error handling -/// \{ - -/// Clear exception flags. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). -/// \param excepts OR of exceptions to clear -/// \retval 0 all selected flags cleared successfully -inline int feclearexcept(int excepts) -{ - detail::errflags() &= ~excepts; - return 0; -} - -/// Test exception flags. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). -/// \param excepts OR of exceptions to test -/// \return OR of selected exceptions if raised -inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } - -/// Raise exception flags. -/// This raises the specified floating point exceptions and also invokes any additional automatic -/// exception handling as -/// configured with the [HALF_ERRHANDLIG_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. 
-/// -/// **See also:** Documentation for -/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). -/// \param excepts OR of exceptions to raise -/// \retval 0 all selected exceptions raised successfully -inline int feraiseexcept(int excepts) -{ - detail::errflags() |= excepts; - detail::raise(excepts); - return 0; -} - -/// Save exception flags. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). -/// \param flagp adress to store flag state at -/// \param excepts OR of flags to save -/// \retval 0 for success -inline int fegetexceptflag(int* flagp, int excepts) -{ - *flagp = detail::errflags() & excepts; - return 0; -} - -/// Restore exception flags. -/// This only copies the specified exception state (including unset flags) without incurring any -/// additional exception handling. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// -/// **See also:** Documentation for -/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). -/// \param flagp adress to take flag state from -/// \param excepts OR of flags to restore -/// \retval 0 for success -inline int fesetexceptflag(const int* flagp, int excepts) -{ - detail::errflags() = (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts); - return 0; -} - -/// Throw C++ exceptions based on set exception flags. 
-/// This function manually throws a corresponding C++ exception if one of the specified flags is -/// set, -/// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref -/// HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. -/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is -/// disabled, -/// but in that case manual flag management is the only way to raise flags. -/// \param excepts OR of exceptions to test -/// \param msg error message to use for exception description -/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set -/// \throw std::overflow_error if `FE_OVERFLOW` is selected and set -/// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set -/// \throw std::range_error if `FE_INEXACT` is selected and set -inline void fethrowexcept(int excepts, const char* msg = "") -{ - excepts &= detail::errflags(); - if(excepts & (FE_INVALID | FE_DIVBYZERO)) - throw std::domain_error(msg); - if(excepts & FE_OVERFLOW) - throw std::overflow_error(msg); - if(excepts & FE_UNDERFLOW) - throw std::underflow_error(msg); - if(excepts & FE_INEXACT) - throw std::range_error(msg); -} -/// \} -} // namespace half_float - -#undef HALF_UNUSED_NOERR -#undef HALF_CONSTEXPR -#undef HALF_CONSTEXPR_CONST -#undef HALF_CONSTEXPR_NOERR -#undef HALF_NOEXCEPT -#undef HALF_NOTHROW -#undef HALF_THREAD_LOCAL -#undef HALF_TWOS_COMPLEMENT_INT -#ifdef HALF_POP_WARNINGS -#pragma warning(pop) -#undef HALF_POP_WARNINGS -#endif - -#endif diff --git a/host/CMakeLists.txt b/host/CMakeLists.txt index 26739efe34..30cc14d8ca 100644 --- a/host/CMakeLists.txt +++ b/host/CMakeLists.txt @@ -1,4 +1,2 @@ add_subdirectory(host_tensor) -add_subdirectory(online_compile) add_subdirectory(driver_offline) -add_subdirectory(driver_online) diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index 927975d449..fec11e99af 100644 --- a/host/driver_offline/CMakeLists.txt +++ 
b/host/driver_offline/CMakeLists.txt @@ -9,11 +9,10 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver ${PROJECT_SOURCE_DIR}/external/rocm/include - ${PROJECT_SOURCE_DIR}/external/half/include ) -set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp) -set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp) +set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) +set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) diff --git a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp similarity index 89% rename from host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 49e0223b33..7bd82bf6d5 100644 --- a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r3.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" template -void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( +void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const OutLengths& 
out_n_ho_wo_k_lengths, @@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); @@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 1 // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 @@ -215,7 +207,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx const auto in_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0>{}), // 2+: gemmk1 @@ -223,7 +215,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -231,7 +223,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 - constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat @@ -251,15 +243,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; - constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = 
driver_dynamic_gemm_xdlops_v2r3< + float ave_time = driver_gemm_xdlops_v2r3< BlockSize, TInWei, TAcc, @@ -295,11 +287,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx Sequence<1, 3, 7, 0, 2, 4, 5, 6>, 6, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(in_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), @@ -307,11 +299,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx wei_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - out_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - in_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { @@ -319,16 +311,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx const auto K = out_n_ho_wo_k_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3]; - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = 
in_n_hi_wi_c_lengths[I2]; - const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Y = wei_k_y_x_c_lengths[I1]; const auto X = wei_k_y_x_c_lengths[I2]; - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" diff --git a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp similarity index 88% rename from host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index ce4dd155f6..0ebf8571f4 100644 --- a/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r3.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" template -void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( +void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const OutLengths& out_n_ho_wo_k_lengths, @@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto I4 
= Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); @@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 @@ -187,7 +179,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k const auto in_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -195,7 +187,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: 
gemmm Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 @@ -203,7 +195,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat @@ -223,15 +215,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 - constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_xdlops_v2r3< + float ave_time = driver_gemm_xdlops_v2r3< BlockSize, TInWei, TAcc, @@ -271,11 +263,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k #endif 7, GemmCThreadTransferDstScalarPerVector, - 
decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(in_m0_m1_m2_n_grid_iterator_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), true // CAccessOrderMRepeatNRepeat >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), @@ -283,11 +275,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k out_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - in_m0_m1_m2_n_grid_iterator_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_m1_m2_n_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { @@ -295,16 +287,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k const auto K = out_n_ho_wo_k_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3]; - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Y = wei_k_y_x_c_lengths[I1]; const auto X = wei_k_y_x_c_lengths[I2]; - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + float perf = 
static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp similarity index 84% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp index 24ba775309..e6554cf0fe 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_gemm_dlops_v1r2.hpp" +#include "driver_gemm_dlops_v1r2.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( +void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * 
wei_k_c_y_x.mDesc.GetElementSpace()); @@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); #if 1 // cdata = 64, BlockSize = 256, 128x128x8 @@ -98,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( in_right_pads); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -108,7 +99,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = + constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), @@ -116,7 +107,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); - constexpr auto 
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -130,10 +121,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; const auto wei_gemmk_gemmm_grid_desc = descs[I0]; @@ -142,7 +133,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_dlops_v1r2< + float ave_time = driver_gemm_dlops_v1r2< BlockSize, TInWei, TAcc, @@ -180,26 +171,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim GemmCThreadTransferDstScalarPerVector_N11, - decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), - decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( + decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), + decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks), + 
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>( static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, - in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, - wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, - in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, + wei_gemmk_gemmm0_gemmn1_grid_step_hacks, + in_gemmk_gemmn0_gemmn1_grid_step_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, + wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks, + in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks, nrepeat); - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp similarity index 94% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp index b6b1cc8969..4a9d01081c 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ 
-1,7 +1,7 @@ #include #include "device.hpp" #include "host_tensor.hpp" -#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" +#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( +void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -48,12 +48,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); #if 0 constexpr index_t BlockSize = 256; @@ -212,9 +209,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw for(index_t i = 0; i < 5; ++i) { #if 0 - float ave_time = launch_kernel_dynamic_gemm_xdlops_v1 + float ave_time = launch_kernel_gemm_xdlops_v1 #else - float ave_time = launch_kernel_dynamic_gemm_xdlops_v2 + float ave_time = launch_kernel_gemm_xdlops_v2 #endif -void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( +void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const 
OutLengths& out_n_ho_wo_k_lengths, @@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); @@ -49,14 +44,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); -#if 1 +#if 0 // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 // cdata = 64, BlockSize = 256 constexpr index_t BlockSize = 256; @@ -163,7 +155,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: 
GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 @@ -173,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 @@ -183,7 +175,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 @@ -197,15 +189,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 - constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; - constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_dlops_v1r3< + float ave_time = driver_gemm_dlops_v1r3< 
BlockSize, TInWei, TAcc, @@ -239,22 +231,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim GemmCThreadTransferDstScalarPerVector_N11, - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), - decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>( + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>( static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, - in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks, - wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { @@ -262,16 +254,13 @@ void 
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw const auto K = out_n_ho_wo_k_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3]; - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Y = wei_k_y_x_c_lengths[I1]; const auto X = wei_k_y_x_c_lengths[I2]; - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp similarity index 84% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index b56cbc0335..695ffeeb36 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_gemm_xdlops_v2r3.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( +void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -34,12 +34,6 @@ void 
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); @@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); #if 1 // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 @@ -101,12 +92,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - 
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -114,7 +105,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -132,15 +123,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_xdlops_v2r3< + float ave_time = driver_gemm_xdlops_v2r3< BlockSize, TInWei, TAcc, @@ -176,26 +167,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk Sequence<3, 0, 1, 2, 7, 5, 4, 6>, 7, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + 
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp similarity index 84% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp index 10284b48f3..141a326574 100644 --- 
a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r2.hpp" +#include "driver_gemm_xdlops_v2r2.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( +void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const OutLengths& out_n_ho_wo_k_lengths, @@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); @@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = 
make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 1 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 @@ -129,12 +121,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -142,7 +134,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -152,15 +144,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_xdlops_v2r2< + float ave_time = driver_gemm_xdlops_v2r2< BlockSize, TInWei, TAcc, @@ -195,22 +187,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh Sequence<2, 3, 0, 1>, 2, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { @@ -218,9 +210,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh const auto K = out_n_ho_wo_k_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3]; - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = 
in_n_hi_wi_c_lengths[I2]; - const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Wo = out_n_ho_wo_k_lengths[I2]; diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp similarity index 90% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp index f2a30fb525..692751bfb3 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r3.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( +void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const OutLengths& out_n_ho_wo_k_lengths, @@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + const auto in_n_hi_wi_c_desc = 
make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 1 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 @@ -185,12 +182,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), make_tuple( Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), @@ -198,7 +195,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -216,15 +213,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{})); - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + 
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_xdlops_v2r3< + float ave_time = driver_gemm_xdlops_v2r3< BlockSize, TInWei, TAcc, @@ -259,11 +256,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 6, GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), @@ -271,11 +268,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh wei_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { diff --git 
a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp similarity index 89% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 601878c347..7067291c8a 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_xdlops_v2r3.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( +void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, const OutLengths& out_n_ho_wo_k_lengths, @@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); @@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); 
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 @@ -241,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 @@ -249,7 +241,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 @@ -257,7 +249,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto 
out_m0_m1_m2_n_grid_iterator_hacks = + constexpr auto out_m0_m1_m2_n_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves @@ -275,15 +267,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_gemm_xdlops_v2r3< + float ave_time = driver_gemm_xdlops_v2r3< BlockSize, TInWei, TAcc, @@ -319,11 +311,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh Sequence<2, 3, 0, 1, 7, 5, 4, 6>, 7, GemmCThreadTransferDstScalarPerVector, - decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), - decltype(out_m0_m1_m2_n_grid_iterator_hacks), - decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), @@ -331,11 +323,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh 
in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, - in_gemmk0_gemmm_gemmk1_grid_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, - out_m0_m1_m2_n_grid_iterator_hacks, - in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); { @@ -343,16 +335,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh const auto K = out_n_ho_wo_k_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3]; - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Y = wei_k_y_x_c_lengths[I1]; const auto X = wei_k_y_x_c_lengths[I2]; - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp similarity index 91% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp index ca0d47c33a..b5e5f91d59 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ 
b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -1,8 +1,8 @@ #include #include "device.hpp" #include "host_tensor.hpp" -#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" +#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( +void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -26,7 +26,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( const Tensor& in_n_c_hi_wi, const Tensor& wei_k_c_y_x, Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) + ck::index_t /* nrepeat */) { using namespace ck; @@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); - const auto in_n_c0_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); - const auto wei_k_c0_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); + const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi)); + const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X)); const auto out_n_k0_ho_wo_k1_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); #if 1 // cdata = 64, BlockSize = 64, 16x8x32x4 diff --git 
a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp similarity index 88% rename from host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp index 8fb276b464..e1b7c5486c 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -3,7 +3,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_contraction_dlops_v1r2.hpp" +#include "driver_contraction_dlops_v1r2.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( +void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - const auto in_desc_n_c_hi_wi = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_desc_k_c_y_x = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_desc_n_k_ho_wo = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_desc_n_k_ho_wo = 
make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); #if 1 // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 @@ -133,7 +130,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_grid_iterator_hacks = + constexpr auto wei_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 @@ -145,7 +142,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - constexpr auto in_grid_iterator_hacks = make_tuple( + constexpr auto in_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 @@ -157,7 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 - constexpr auto out_grid_iterator_hacks = make_tuple( + constexpr auto out_grid_step_hacks = make_tuple( make_tuple( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 @@ -173,14 +170,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 - constexpr auto wei_grid_move_slice_window_iterator_hacks = 
Sequence<0, 0, 0, 0, 0, 0, 0>{}; + constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; - constexpr auto in_grid_move_slice_window_iterator_hacks = + constexpr auto in_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_dynamic_contraction_dlops_v1r2< + float ave_time = driver_contraction_dlops_v1r2< BlockSize, TInWei, TAcc, @@ -214,26 +211,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim CThreadTransferDstScalarPerVector_BN1, - decltype(wei_grid_iterator_hacks), - decltype(in_grid_iterator_hacks), - decltype(out_grid_iterator_hacks), - decltype(wei_grid_move_slice_window_iterator_hacks), - decltype(in_grid_move_slice_window_iterator_hacks)>( + decltype(wei_grid_step_hacks), + decltype(in_grid_step_hacks), + decltype(out_grid_step_hacks), + decltype(wei_grid_move_slice_window_step_hacks), + decltype(in_grid_move_slice_window_step_hacks)>( static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), wei_grid_desc_gk0_gm0_gm1_gk1, in_grid_desc_gk0_gn0_gn1_gk1, out_grid_desc_gm0_gm1_gn0_gn1, - wei_grid_iterator_hacks, - in_grid_iterator_hacks, - out_grid_iterator_hacks, - wei_grid_move_slice_window_iterator_hacks, - in_grid_move_slice_window_iterator_hacks, + wei_grid_step_hacks, + in_grid_step_hacks, + out_grid_step_hacks, + wei_grid_move_slice_window_step_hacks, + in_grid_move_slice_window_step_hacks, nrepeat); - float perf = (float)calculate_convolution_flops( - in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) / + float perf = static_cast(calculate_convolution_flops( + in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) / (std::size_t(1000) * 1000 * 1000) / ave_time; 
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; diff --git a/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp similarity index 84% rename from host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp rename to host/driver_offline/include/driver_contraction_dlops_v1r2.hpp index 2f175962c1..d207728a2e 100644 --- a/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp +++ b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp @@ -1,10 +1,10 @@ -#ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP -#define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP +#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP +#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_contraction_dlops_v1r2.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_contraction_dlops_v1r2.hpp" template + typename AGridStepHacks, + typename BGridStepHacks, + typename CGridStepHacks, + typename AGridMoveSliceWindowStepHacks, + typename BGridMoveSliceWindowStepHacks> __host__ float -driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - ck::index_t nrepeat) +driver_contraction_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const 
CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) { using namespace ck; @@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, // GEMM using GridwiseContraction = - GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< BlockSize, FloatAB, FloatAcc, @@ -104,11 +104,11 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); @@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1)) { throw std::runtime_error("wrong! 
" - "GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" + "GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" "GM0_GM1_GN0_GN1 has invalid setting"); } @@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, if(has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + const auto kernel = kernel_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -205,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, } else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + const auto kernel = kernel_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -232,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, } else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + const auto kernel = kernel_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -259,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, } else { - const auto kernel = kernel_dynamic_contraction_dlops_v1r2< + const auto kernel = kernel_contraction_dlops_v1r2< GridwiseContraction, FloatAB, FloatC, @@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, diff --git 
a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp similarity index 86% rename from host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp index 7c4b1043f3..efd4ce6a19 100644 --- a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -1,10 +1,10 @@ -#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP -#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_dlops_v2.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v2.hpp" #include "gridwise_operation_wrapper.hpp" template - __host__ void Run(const ck::DynamicTensorDescriptor& wei_k_c_y_x_global_desc, - const ck::DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const ck::DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, + const ck::TensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad const auto 
InRightPadW = in_right_pads[I1]; // weight tensor - const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + const auto wei_e_k_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple( make_pass_through_transform(N), @@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( in_n_c_y_ho_x_wo_global_desc, make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_pass_through_transform(N), @@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // output tensor - const auto out_k_n_ho_wo_global_desc = 
transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), + const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), make_tuple(make_merge_transform(make_tuple(K0, K1)), make_pass_through_transform(N), make_pass_through_transform(Ho), @@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad } // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_iterator_hacks = + constexpr auto a_e_k_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - constexpr auto b_e_n_ho_wo_global_iterator_hacks = + constexpr auto b_e_n_ho_wo_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -169,7 +169,7 @@ struct 
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad #if 1 // GEMM - using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< + using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< BlockSize, FloatAB, FloatAcc, @@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad Sequence<0, 2, 3, 1>, 0, CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_iterator_hacks), - decltype(b_e_n_ho_wo_global_iterator_hacks), - decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), - decltype(a_e_k_global_move_slice_window_iterator_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + decltype(a_e_k_global_step_hacks), + decltype(b_e_n_ho_wo_global_step_hacks), + decltype(c_k_n_ho_wo_global_tensor_step_hacks), + decltype(a_e_k_global_move_slice_window_step_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; @@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -338,10 +334,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad float ave_time = timer.GetElapsedTime() / nrepeat; - float perf = 
(float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k0_ho_wo_k1_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; + float perf = + static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; diff --git a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp similarity index 86% rename from host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp rename to host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp index b7f8e6039c..70f73cbf4a 100644 --- a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp +++ b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp @@ -1,10 +1,10 @@ -#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP -#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP +#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP +#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP #include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_dlops_v2.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v2.hpp" #include "gridwise_operation_wrapper.hpp" template - __host__ void Run(const ck::DynamicTensorDescriptor& 
wei_k_c_y_x_global_desc, - const ck::DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, - const ck::DynamicTensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, + const ck::TensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp << std::endl; // weight tensor - const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + const auto wei_e_k_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, make_tuple(make_pass_through_transform(N), make_pass_through_transform(C), @@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple( make_pass_through_transform(N), @@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, 
Sequence<4, 5>{})); - const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( in_n_c_y_ho_x_wo_global_desc, make_tuple(make_merge_transform(make_tuple(C, Y, X)), make_pass_through_transform(N), @@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // output tensor - const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)), + const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), make_tuple(make_merge_transform(make_tuple(K0, K1)), make_pass_through_transform(N), make_pad_transform(Ho, 0, OutRightPadH), @@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp } // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto a_e_k_global_iterator_hacks = + constexpr auto a_e_k_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; - constexpr auto b_e_n_ho_wo_global_iterator_hacks = + constexpr auto b_e_n_ho_wo_global_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - constexpr auto 
b_e_n_ho_wo_global_move_slice_window_iterator_hack = + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor // hack for NKHW format - constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, @@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp Sequence<0, 0, 0, 0, 0>{})); // GEMM - using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3< + using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< BlockSize, FloatAB, FloatAcc, @@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp Sequence<0, 2, 3, 1>, 0, CThreadTransferDstScalarPerVector_W, - decltype(a_e_k_global_iterator_hacks), - decltype(b_e_n_ho_wo_global_iterator_hacks), - decltype(c_k_n_ho_wo_global_tensor_iterator_hacks), - decltype(a_e_k_global_move_slice_window_iterator_hack), - decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>; + decltype(a_e_k_global_step_hacks), + decltype(b_e_n_ho_wo_global_step_hacks), + decltype(c_k_n_ho_wo_global_tensor_step_hacks), + decltype(a_e_k_global_move_slice_window_step_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; @@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -311,7 +309,6 @@ struct 
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -354,10 +350,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp float ave_time = timer.GetElapsedTime() / nrepeat; - float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, - wei_k_c_y_x_global_desc, - out_n_k0_ho_wo_k1_global_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; + float perf = + static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp deleted file mode 100644 index 0ebc68b48a..0000000000 --- a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp +++ /dev/null @@ -1,415 +0,0 @@ -#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 -#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_dlops_v1r2.hpp" - -template -__host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AKMGridDesc& a_k_m_grid_desc, - const BKNGridDesc& b_k_n_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr 
auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = - GridwiseDynamicGemmDlops_km_kn_mn_v1r2; - - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error( - "wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting"); - } - - const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); - const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); - - using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); - using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // c_blockid_to_m0_n0_block_cluster_adaptor - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); - - { - std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " - << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " - << 
b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - 
c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); - DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); - DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); - DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( - sizeof(CBlockIdToM0N0BlockClusterAdaptor)); - - a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); - b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( - &c_blockid_to_m0_n0_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto 
kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r2, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} -#endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp 
b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp deleted file mode 100644 index d075eac822..0000000000 --- a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp +++ /dev/null @@ -1,411 +0,0 @@ -#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 -#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_dlops_v1r3.hpp" - -template -__host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseGemm = - GridwiseDynamicGemmDlops_km_kn_mn_v1r3; - - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error( - "wrong! 
GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting"); - } - - const auto a_k0_m0_m1_k1_grid_desc = - GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); - const auto b_k0_n0_n1_k1_grid_desc = - GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); - - using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); - using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); - - // c_m0_m10_m11_n0_n10_n11_grid_desc - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = - GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); - - using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); - - // c_blockid_to_m0_n0_block_cluster_adaptor - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); - - const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); - - { - std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " - << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " - << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " - 
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " - << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; - } - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - else - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - 
a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_blockid_to_m0_n0_block_cluster_adaptor); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); - DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); - DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); - DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( - sizeof(CBlockIdToM0N0BlockClusterAdaptor)); - - a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); - b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( - &c_blockid_to_m0_n0_block_cluster_adaptor); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void 
CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = - kernel_dynamic_gemm_dlops_v1r3, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} -#endif diff --git a/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp deleted file mode 100644 index 481d08188d..0000000000 --- a/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp +++ /dev/null @@ -1,196 +0,0 @@ -#ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 -#define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 - -#include 
"common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" - -template -__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - ck::index_t nrepeat) - -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - using GridwiseGemm = - GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; - - { - std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " - << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " - << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) - << "}" << std::endl; - - std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " - << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) - { - throw std::runtime_error( - "wrong! 
GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); - } - - const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); - - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); - - const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); - - using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); - - const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); - - const auto kernel = kernel_dynamic_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t>; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, - c_block_cluster_adaptor); - -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); - DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); - DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); - - a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); - b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); - c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - - float ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); -#endif - return ave_time; -} -#endif diff --git 
a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000..bf5f7f1c0f --- /dev/null +++ b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp @@ -0,0 +1,413 @@ +#ifndef DRIVER_GEMM_DLOPS_V1R2 +#define DRIVER_GEMM_DLOPS_V1R2 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v1r2.hpp" + +template +__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2; + + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error("wrong! 
GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting"); + } + + const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); + + { + std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " + << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " + << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem 
a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); + DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); + b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + 
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp new file mode 100644 index 0000000000..4470918820 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp @@ -0,0 +1,418 @@ +#ifndef DRIVER_GEMM_DLOPS_V1R3 +#define DRIVER_GEMM_DLOPS_V1R3 + +#include 
"common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v1r3.hpp" + +template +__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r3; + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error("wrong! 
GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting"); + } + + const auto a_k0_m0_m1_k1_grid_desc = + GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); + + using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); + using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + { + std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << 
c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + 
c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); + DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); + b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + 
cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp 
b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..edfce52a19 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -0,0 +1,191 @@ +#ifndef DRIVER_GEMM_XDLOPS_V2R3 +#define DRIVER_GEMM_XDLOPS_V2R3 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +template +__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + { + std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + + const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); + + const auto kernel = kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); + +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); + DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); + DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + + a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); + b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); + c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + + float ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); 
+#endif + return ave_time; +} +#endif diff --git a/host/driver_offline/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp similarity index 67% rename from host/driver_offline/conv_bwd_driver_offline.cpp rename to host/driver_offline/src/conv_bwd_driver_offline.cpp index 61c3fc385d..67cea94813 100644 --- a/host/driver_offline/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -12,10 +12,10 @@ #include "conv_common.hpp" #include "host_conv_bwd_data.hpp" #include "device_tensor.hpp" -#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" -#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_DYNAMIC_MODE 1 +#define USE_MODE 1 #define USE_CONV_BWD_V4R1_XDL_NHWC 1 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 @@ -37,7 +37,7 @@ int main(int argc, char* argv[]) constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; -#if USE_DYNAMIC_MODE +#if USE_MODE // dynamic mode if(argc != 22) { @@ -46,29 +46,29 @@ int main(int argc, char* argv[]) exit(1); } - const ConvTensorLayout layout = static_cast(atoi(argv[1])); - const ConvBackwardDataAlgo algo = static_cast(atoi(argv[2])); - const bool do_verification = atoi(argv[3]); - const int init_method = atoi(argv[4]); - const bool do_log = atoi(argv[5]); - const int nrepeat = atoi(argv[6]); + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); - const index_t N = atoi(argv[7]); - const index_t K = atoi(argv[8]); - const index_t C 
= atoi(argv[9]); - const index_t Y = atoi(argv[10]); - const index_t X = atoi(argv[11]); - const index_t Hi = atoi(argv[12]); - const index_t Wi = atoi(argv[13]); + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); - const index_t conv_stride_h = atoi(argv[14]); - const index_t conv_stride_w = atoi(argv[15]); - const index_t conv_dilation_h = atoi(argv[16]); - const index_t conv_dilation_w = atoi(argv[17]); - const index_t in_left_pad_h = atoi(argv[18]); - const index_t in_left_pad_w = atoi(argv[19]); - const index_t in_right_pad_h = atoi(argv[20]); - const index_t in_right_pad_w = atoi(argv[21]); + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1; @@ -83,12 +83,12 @@ int main(int argc, char* argv[]) exit(1); } - const ConvTensorLayout layout = static_cast(atoi(argv[1])); - const ConvBackwardDataAlgo algo = static_cast(atoi(argv[2])); - const bool do_verification = atoi(argv[3]); - const int init_method = atoi(argv[4]); - const bool do_log = atoi(argv[5]); - const int nrepeat = atoi(argv[6]); + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); 
+ const int nrepeat = std::stoi(argv[6]); constexpr index_t N = 128; constexpr index_t C = 192; @@ -115,23 +115,19 @@ int main(int argc, char* argv[]) #endif #if 0 - constexpr index_t in_vector_size = 1; using in_data_t = float; using acc_data_t = float; using out_data_t = float; #elif 1 - constexpr index_t in_vector_size = 1; - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - switch(layout) + if(layout == ConvTensorLayout::NCHW) { - case ConvTensorLayout::NCHW: - // NCHW in_lengths_host[0] = static_cast(N); in_lengths_host[1] = static_cast(C); in_lengths_host[2] = static_cast(Hi); @@ -144,9 +140,9 @@ int main(int argc, char* argv[]) out_lengths_host[1] = static_cast(K); out_lengths_host[2] = static_cast(Ho); out_lengths_host[3] = static_cast(Wo); - break; - case ConvTensorLayout::NHWC: - // NHWC + } + else if(layout == ConvTensorLayout::NHWC) + { in_lengths_host[0] = static_cast(N); in_lengths_host[1] = static_cast(Hi); in_lengths_host[2] = static_cast(Wi); @@ -159,8 +155,10 @@ int main(int argc, char* argv[]) out_lengths_host[1] = static_cast(Ho); out_lengths_host[2] = static_cast(Wo); out_lengths_host[3] = static_cast(K); - break; - default: throw std::runtime_error("wrong! not implemented"); + } + else + { + throw std::runtime_error("wrong! 
not implemented"); } Tensor in_host(in_lengths_host); @@ -213,40 +211,8 @@ int main(int argc, char* argv[]) wei.GenerateTensorValue(gen_wei, num_thread); } - auto f_make_for_device_nchw = [&]() { -#if USE_DYNAMIC_MODE - const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); - const auto wei_lengths_dev = make_tuple(K, C, Y, X); - const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); - const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif - - return make_tuple(in_lengths_dev, - wei_lengths_dev, - out_lengths_dev, - conv_strides_dev, - conv_dilations_dev, - in_left_pads_dev, - in_right_pads_dev); - }; - auto f_make_for_device_nhwc = [&]() { -#if USE_DYNAMIC_MODE +#if USE_MODE const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); @@ -277,8 +243,6 @@ int main(int argc, char* argv[]) in_right_pads_dev); }; - const auto nhwc_desc = f_make_for_device_nhwc(); - #if USE_CONV_BWD_V4R1_XDL_NHWC if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC) { @@ -289,20 +253,20 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - 
device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk< - in_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); + device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); } #endif @@ -316,20 +280,20 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk< - in_data_t, - acc_data_t, - out_data_t>(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); } #endif diff --git a/host/driver_offline/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp similarity index 67% rename from host/driver_offline/conv_fwd_driver_offline.cpp rename to host/driver_offline/src/conv_fwd_driver_offline.cpp index ef2e16c4fa..32c33003c5 100644 --- a/host/driver_offline/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -12,17 +12,17 @@ #include "conv_common.hpp" #include "host_conv.hpp" #include "device_tensor.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" -#include 
"device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_DYNAMIC_MODE 1 +#define USE_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 1 #define USE_CONV_FWD_V4R4R2_NHWC 1 -#define USE_CONV_FWD_V6R1_NCHW 1 +#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 @@ -49,7 +49,7 @@ int main(int argc, char* argv[]) constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; -#if USE_DYNAMIC_MODE +#if USE_MODE // dynamic mode if(argc != 22) { @@ -58,29 +58,29 @@ int main(int argc, char* argv[]) exit(1); } - const ConvTensorLayout layout = static_cast(atoi(argv[1])); - const ConvForwardAlgo algo = static_cast(atoi(argv[2])); - const bool do_verification = atoi(argv[3]); - const int init_method = atoi(argv[4]); - const bool do_log = atoi(argv[5]); - const int nrepeat = atoi(argv[6]); + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); - const index_t N = atoi(argv[7]); - const index_t K = atoi(argv[8]); - const index_t C = atoi(argv[9]); - const index_t Y = atoi(argv[10]); - const index_t X = atoi(argv[11]); - const index_t Hi = atoi(argv[12]); - const index_t Wi = atoi(argv[13]); + const index_t N = 
std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); - const index_t conv_stride_h = atoi(argv[14]); - const index_t conv_stride_w = atoi(argv[15]); - const index_t conv_dilation_h = atoi(argv[16]); - const index_t conv_dilation_w = atoi(argv[17]); - const index_t in_left_pad_h = atoi(argv[18]); - const index_t in_left_pad_w = atoi(argv[19]); - const index_t in_right_pad_h = atoi(argv[20]); - const index_t in_right_pad_w = atoi(argv[21]); + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1; @@ -95,12 +95,12 @@ int main(int argc, char* argv[]) exit(1); } - const ConvTensorLayout layout = static_cast(atoi(argv[1])); - const ConvForwardAlgo algo = static_cast(atoi(argv[2])); - const bool do_verification = atoi(argv[3]); - const int init_method = atoi(argv[4]); - const bool do_log = atoi(argv[5]); - const int nrepeat = atoi(argv[6]); + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); constexpr index_t N = 128; constexpr index_t C = 192; @@ -142,10 +142,8 @@ int main(int argc, char* argv[]) std::vector in_lengths_host(4), 
wei_lengths_host(4), out_lengths_host(4); - switch(layout) + if(layout == ConvTensorLayout::NCHW) { - case ConvTensorLayout::NCHW: - // NCHW in_lengths_host[0] = static_cast(N); in_lengths_host[1] = static_cast(C); in_lengths_host[2] = static_cast(Hi); @@ -158,9 +156,9 @@ int main(int argc, char* argv[]) out_lengths_host[1] = static_cast(K); out_lengths_host[2] = static_cast(Ho); out_lengths_host[3] = static_cast(Wo); - break; - case ConvTensorLayout::NHWC: - // NHWC + } + else if(layout == ConvTensorLayout::NHWC) + { in_lengths_host[0] = static_cast(N); in_lengths_host[1] = static_cast(Hi); in_lengths_host[2] = static_cast(Wi); @@ -173,8 +171,10 @@ int main(int argc, char* argv[]) out_lengths_host[1] = static_cast(Ho); out_lengths_host[2] = static_cast(Wo); out_lengths_host[3] = static_cast(K); - break; - default: throw std::runtime_error("wrong! not implemented"); + } + else + { + throw std::runtime_error("wrong! not implemented"); } Tensor in(in_lengths_host); @@ -228,7 +228,7 @@ int main(int argc, char* argv[]) } auto f_make_for_device_nchw = [&]() { -#if USE_DYNAMIC_MODE +#if USE_MODE const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); const auto wei_lengths_dev = make_tuple(K, C, Y, X); const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); @@ -260,7 +260,7 @@ int main(int argc, char* argv[]) }; auto f_make_for_device_nhwc = [&]() { -#if USE_DYNAMIC_MODE +#if USE_MODE const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); @@ -301,20 +301,19 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + 
wei, + out_device, + nrepeat); } #endif @@ -328,20 +327,19 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); } #endif @@ -355,20 +353,19 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); } #endif @@ -382,21 +379,20 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); + device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); } #endif @@ -410,9 +406,9 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( tmp[I0], tmp[I1], tmp[I2], @@ -437,9 +433,9 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + 
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( tmp[I0], tmp[I1], tmp[I2], @@ -467,7 +463,6 @@ int main(int argc, char* argv[]) check_error(out_host, out_device); -#if 0 if(do_log) { LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; @@ -475,6 +470,5 @@ int main(int argc, char* argv[]) LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; } -#endif } } diff --git a/host/driver_online/CMakeLists.txt b/host/driver_online/CMakeLists.txt deleted file mode 100644 index 077e3218a0..0000000000 --- a/host/driver_online/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -include_directories(BEFORE - include - ${PROJECT_BINARY_DIR}/host/online_compile/include - ${PROJECT_SOURCE_DIR}/host/online_compile/include - ${PROJECT_SOURCE_DIR}/host/host_tensor/include - ${PROJECT_SOURCE_DIR}/host/solver/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation - ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform - ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver - ${PROJECT_SOURCE_DIR}/external/rocm/include - ${PROJECT_SOURCE_DIR}/external/half/include -) - -set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp) - -add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE}) - -target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor) -target_link_libraries(conv_fwd_driver_online PRIVATE online_compile) diff --git a/host/driver_online/conv_fwd_driver_online.cpp b/host/driver_online/conv_fwd_driver_online.cpp deleted file mode 100644 index 29609d5474..0000000000 --- a/host/driver_online/conv_fwd_driver_online.cpp +++ /dev/null @@ -1,453 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include 
"config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "host_conv.hpp" -#include "device_tensor.hpp" -#include "handle.hpp" -#include "hipCheck.hpp" -#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" -#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" -#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" -#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" - -#define USE_CONV_FWD_V4R4_NCHW 1 -#define USE_CONV_FWD_V6R1_NCHW 1 -#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1 -#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1 - -enum ConvForwardAlgo -{ - V4R4NCHW, // 0 - V6R1NCHW, // 1 - V4R4XDLNCHW, // 2 - V4R4XDLNHWC // 3 -}; - -int main(int argc, char* argv[]) -{ - using namespace ck; - using namespace ck_driver; - using size_t = std::size_t; - - hipStream_t stream; - online_compile::Handle* handle; - - MY_HIP_CHECK(hipStreamCreate(&stream)); - - handle = new online_compile::Handle(stream); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - - if(argc != 22) - { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); - printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); - exit(1); - } - - const ConvTensorLayout layout = static_cast(atoi(argv[1])); - const ConvForwardAlgo algo = static_cast(atoi(argv[2])); - const bool do_verification = atoi(argv[3]); - const int init_method = atoi(argv[4]); - const bool do_log = atoi(argv[5]); - const int nrepeat = atoi(argv[6]); - - const index_t N = atoi(argv[7]); - const index_t K = atoi(argv[8]); - 
const index_t C = atoi(argv[9]); - const index_t Y = atoi(argv[10]); - const index_t X = atoi(argv[11]); - const index_t Hi = atoi(argv[12]); - const index_t Wi = atoi(argv[13]); - - const index_t conv_stride_h = atoi(argv[14]); - const index_t conv_stride_w = atoi(argv[15]); - const index_t conv_dilation_h = atoi(argv[16]); - const index_t conv_dilation_w = atoi(argv[17]); - const index_t in_left_pad_h = atoi(argv[18]); - const index_t in_left_pad_w = atoi(argv[19]); - const index_t in_right_pad_h = atoi(argv[20]); - const index_t in_right_pad_w = atoi(argv[21]); - - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; - - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - -#if 1 - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; -#elif 0 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; -#elif 1 - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - - switch(layout) - { - case ConvTensorLayout::NCHW: - // NCHW - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(C); - in_lengths_host[2] = static_cast(Hi); - in_lengths_host[3] = static_cast(Wi); - - wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(C); - wei_lengths_host[2] = static_cast(Y); - wei_lengths_host[3] = static_cast(X); - - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(K); - out_lengths_host[2] = static_cast(Ho); - out_lengths_host[3] = static_cast(Wo); - break; - case ConvTensorLayout::NHWC: - // NHWC - in_lengths_host[0] = static_cast(N); - in_lengths_host[1] = static_cast(Hi); - in_lengths_host[2] = static_cast(Wi); - in_lengths_host[3] = static_cast(C); - - 
wei_lengths_host[0] = static_cast(K); - wei_lengths_host[1] = static_cast(Y); - wei_lengths_host[2] = static_cast(X); - wei_lengths_host[3] = static_cast(C); - - out_lengths_host[0] = static_cast(N); - out_lengths_host[1] = static_cast(Ho); - out_lengths_host[2] = static_cast(Wo); - out_lengths_host[3] = static_cast(K); - break; - default: throw std::runtime_error("wrong! not implemented"); - } - - Tensor in(in_lengths_host); - Tensor wei(wei_lengths_host); - Tensor out_host(out_lengths_host); - Tensor out_device(out_lengths_host); - - std::cout << "layout: " << layout << std::endl; - ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); - ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); - ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); - print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); - print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); - print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); - print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - switch(init_method) - { - case 0: - // no initialization - break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 3: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 4: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 5: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); - break; - default: - 
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); - }; - wei.GenerateTensorValue(gen_wei, num_thread); - } - - auto f_make_for_device_nchw = [&]() { - const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); - const auto wei_lengths_dev = make_tuple(K, C, Y, X); - const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); - - return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev); - }; - - auto f_make_for_device_nhwc = [&]() { - const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); - const auto wei_lengths_dev = make_tuple(K, Y, X, C); - const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); - - return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev); - }; - - const auto conv_strides = make_tuple(conv_stride_h, conv_stride_w); - const auto conv_dilations = make_tuple(conv_dilation_h, conv_dilation_w); - const auto in_left_pads = make_tuple(in_left_pad_h, in_left_pad_w); - const auto in_right_pads = make_tuple(in_right_pad_h, in_right_pad_w); - -#if USE_CONV_FWD_V4R4_NCHW - if(algo == ConvForwardAlgo::V4R4NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable = - &default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw; - - online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw< - in_data_t, - acc_data_t, - out_data_t>(handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V6R1_NCHW - if(algo == ConvForwardAlgo::V6R1NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nchw(); - -#if 1 - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - 256, - 4, - 1, - 128, - 32, - 8, - 4, - 4, - 1, - {8, 2}, - {8, 2}, - {4, 1, 1, 1, 1}, - {2, 1, 1, 128, 1}, - {4, 1, 1, 1, 1}, - {1, 1, 1, 1, 1}, - {1, 4, 1, 1, 1}, - {8, 1, 1, 32, 1}, - {1, 1, 1, 1, 1}, - {1, 1, 1, 1, 1}, - 4, - true, - true}; -#elif 0 - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - 256, - 4, - 2, - 128, - 32, - 8, - 4, - 4, - 1, - {8, 2}, - {8, 2}, - {4, 1, 1, 1, 2}, - {2, 1, 1, 128, 1}, - {4, 1, 1, 1, 1}, - {1, 1, 1, 1, 1}, - {1, 4, 1, 1, 2}, - {8, 1, 1, 32, 1}, - {1, 1, 1, 1, 1}, - {1, 1, 1, 1, 1}, - 4, - true, - true}; -#elif 1 - const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = { - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - 256, - 4, - 4, - 128, - 32, - 8, - 4, - 4, - 1, - {8, 2}, - {8, 2}, - {4, 1, 1, 1, 4}, - {2, 1, 1, 128, 1}, - {4, 1, 1, 1, 1}, - {1, 1, 1, 1, 1}, - {1, 4, 1, 1, 4}, - {8, 1, 1, 32, 1}, - {1, 1, 1, 1, 1}, - {1, 1, 1, 1, 1}, - 4, - true, - true}; -#endif - - online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw< - in_data_t, - acc_data_t, - out_data_t>(handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - compile_param, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4_XDLOPS_NCHW - if(algo == ConvForwardAlgo::V4R4XDLNCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable = - &default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; - - online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw< - in_data_t, - acc_data_t, - out_data_t>(handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); - } -#endif - -#if USE_CONV_FWD_V4R4_XDLOPS_NHWC - if(algo == ConvForwardAlgo::V4R4XDLNHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable = - &default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk; - - online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk< - in_data_t, - acc_data_t, - out_data_t>(handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); - } -#endif - - if(do_verification) - { - host_direct_convolution(in, - wei, - out_host, - make_tuple(conv_stride_h, conv_stride_w), - make_tuple(conv_dilation_h, conv_dilation_w), - make_tuple(in_left_pad_h, in_left_pad_w), - make_tuple(in_right_pad_h, in_right_pad_w), - layout); - - check_error(out_host, out_device); - -#if 0 - if(do_log) - { - LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; - LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; - } -#endif - } - - delete handle; - MY_HIP_CHECK(hipStreamDestroy(stream)); -} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp 
b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 06412fba0b..0000000000 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,395 +0,0 @@ -#pragma once -#include "device.hpp" -#include "host_tensor.hpp" -#include "handle.hpp" -#include "online_driver_common.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp" - -namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw { - -template -static std::string get_network_config_string_from_types() -{ - using namespace ck; - - std::string out; - - out += std::to_string(get_datatype_enum_from_type::value) + "_" + - std::to_string(get_datatype_enum_from_type::value) + "_" + - std::to_string(get_datatype_enum_from_type::value); - - return (out); -}; - -static std::string -get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt) -{ - std::string out("TUN_"); - - out += std::to_string(pt->BlockSize) + "_"; - - out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + - std::to_string(pt->KPerBlock) + "_"; - out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" + - std::to_string(pt->KPerThread) + "_"; - out += std::to_string(pt->M1N1ThreadClusterM10) + "x" + - std::to_string(pt->M1N1ThreadClusterN10) + "x" + - std::to_string(pt->M1N1ThreadClusterM11) + "x" + - std::to_string(pt->M1N1ThreadClusterN11) + "_"; - - out += std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]) + "_"; - - out += 
std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]) + "_"; - - out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->ABlockTransferDstScalarPerVector_M1) + "_"; - out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; - out += 
std::to_string(pt->BBlockTransferDstScalarPerVector_N1) + "_"; - out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; - out += std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -template -static std::string get_definition_string_from_types() -{ - using namespace ck; - - std::string out; - - out += - " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + - " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + - " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); - - return (out); -}; - -static std::string -get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt) -{ - std::string out; - - out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); - - out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + - " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + - " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); - out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) + - " -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) + - " -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread); - - out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) + - " -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) + - " -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) + - " -DCK_PARAM_M1N1ThreadClusterN11=" 
+ std::to_string(pt->M1N1ThreadClusterN11); - - out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1=" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1=" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); - - out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + - std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]); - - out += - " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); - out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + - std::to_string(pt->ABlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_M1=" + - std::to_string(pt->ABlockTransferDstScalarPerVector_M1); - out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1=" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1=" + - 
std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]); - - out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + - std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]); - - out += - " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); - out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + - std::to_string(pt->BBlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_N1=" + - std::to_string(pt->BBlockTransferDstScalarPerVector_N1); - out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]); - - out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + - std::to_string(pt->CThreadTransferSrcDstVectorDim); - out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + - std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -} // namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw - -template -void 
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( - online_compile::Handle* handle, - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable, - ck::index_t nrepeat) -{ - using namespace ck; - using namespace ck_driver; - using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; - using size_t = std::size_t; - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, - // hasDoubleTailKBlockLoop - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - - const auto descs = - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - const auto a_k_m_grid_desc = descs[I0]; - const auto c_m_n_grid_desc = descs[I2]; - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock); - const bool hasMainKBlockLoop = ((K + tunable->KPerBlock) / (2 * 
tunable->KPerBlock) > 1); - const bool hasDoubleTailKBlockLoop = ((K / tunable->KPerBlock) % 2 == 0); - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - - // these buffers are usually provided by the user application - DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - // these are workspace buffers that should be expressed to the user by the corresponding - // workspace API - DeviceMem workspace_buf(4096); - - void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); - void* b_k_n0_n1_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); - void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); - void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); - - const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; - - std::string program_name = - "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp"; - std::string algo_name = "implicit_gemm_conv_fwd_v4r4_dlops_nchw"; - - std::string param = " -std=c++17 "; - std::string network_config; - - param += get_definition_string_from_types() + " " + - get_definition_string_from_tunable(tunable) + - " -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) + - " -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + 
std::to_string(hasDoubleTailKBlockLoop); - network_config = get_network_config_string_from_types() + "_" + - get_network_config_string_from_tunable(tunable) + "_" + - std::to_string(hasMainKBlockLoop) + "_" + - std::to_string(hasDoubleTailKBlockLoop); - - std::vector kernel1_times; - std::vector kernel2_times; - - for(index_t i = 0; i < nrepeat; ++i) - { - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare"; - auto network_config_1 = network_config + "_1"; - - timer1.Start(); - handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( - static_cast(in_n_c_hi_wi_lengths[I0]), - static_cast(in_n_c_hi_wi_lengths[I1]), - static_cast(in_n_c_hi_wi_lengths[I2]), - static_cast(in_n_c_hi_wi_lengths[I3]), - static_cast(wei_k_c_y_x_lengths[I0]), - static_cast(wei_k_c_y_x_lengths[I2]), - static_cast(wei_k_c_y_x_lengths[I3]), - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - a_k_m0_m1_grid_desc_dev_buf, - b_k_n0_n1_grid_desc_dev_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf, - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); - timer1.End(); - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw"; - auto network_config_2 = network_config + "_2"; - - timer2.Start(); - handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( - reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), - reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), - reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), - (const void*)(a_k_m0_m1_grid_desc_dev_buf), - (const void*)(b_k_n0_n1_grid_desc_dev_buf), - (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf), - (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); - timer2.End(); - - 
kernel1_times.push_back(timer1.GetElapsedTime()); - kernel2_times.push_back(timer2.GetElapsedTime()); - } - - { - auto ave_time1 = - std::accumulate( - std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / - (nrepeat - 1); - auto ave_time2 = - std::accumulate( - std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / - (nrepeat - 1); - - const auto N = in_n_c_hi_wi_lengths[I0]; - const auto C = in_n_c_hi_wi_lengths[I1]; - - const auto K = out_n_k_ho_wo_lengths[I1]; - const auto Ho = out_n_k_ho_wo_lengths[I2]; - const auto Wo = out_n_k_ho_wo_lengths[I3]; - - const auto Y = wei_k_c_y_x_lengths[I2]; - const auto X = wei_k_c_y_x_lengths[I3]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); - - std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " - << ave_time2 << "), " << perf << " TFlop/s" << std::endl; - }; - - // copy result back to host - out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 61ce41fe84..0000000000 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,386 +0,0 @@ -#include "device.hpp" -#include "host_tensor.hpp" -#include "handle.hpp" -#include "online_driver_common.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp" - -namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw { - -template -static std::string get_network_config_string_from_types() -{ - using namespace ck; - - std::string out; - - out += 
std::to_string(get_datatype_enum_from_type::value) + "_" + - std::to_string(get_datatype_enum_from_type::value) + "_" + - std::to_string(get_datatype_enum_from_type::value); - - return (out); -}; - -static std::string -get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt) -{ - std::string out("TUN_"); - - out += std::to_string(pt->BlockSize) + "_"; - - out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + - std::to_string(pt->KPerBlock) + "_"; - out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" + - std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" + - std::to_string(pt->K1) + "_"; - - out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_"; - - out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_"; - - out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_"; - out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" + - 
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_"; - out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; - out += std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -template -static std::string get_definition_string_from_types() -{ - using namespace ck; - - std::string out; - - out += - " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + - " 
-DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + - " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); - - return (out); -}; - -static std::string -get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt) -{ - std::string out; - - out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); - - out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + - " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + - " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); - out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) + - " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) + - " -DCK_PARAM_K1=" + std::to_string(pt->K1) + - " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) + - " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat); - - out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); - - out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + - std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]); - - out += - " 
-DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); - out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + - std::to_string(pt->ABlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" + - std::to_string(pt->ABlockTransferDstScalarPerVector_K1); - out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]); - - out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + - std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]); - - out += - " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); - out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + - std::to_string(pt->BBlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" + - std::to_string(pt->BBlockTransferDstScalarPerVector_K1); - out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + - 
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]); - - out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + - std::to_string(pt->CThreadTransferSrcDstVectorDim); - out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + - std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw - -template -void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - online_compile::Handle* handle, - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable, - ck::index_t nrepeat) -{ - using namespace ck; - using namespace ck_driver; - using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; - using size_t = std::size_t; - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto in_n_c_hi_wi_desc = - 
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - - const auto n = in_n_c_hi_wi_desc.GetLength(I0); - const auto c = in_n_c_hi_wi_desc.GetLength(I1); - const auto hi = in_n_c_hi_wi_desc.GetLength(I2); - const auto wi = in_n_c_hi_wi_desc.GetLength(I3); - const auto k = wei_k_c_y_x_desc.GetLength(I0); - const auto y = wei_k_c_y_x_desc.GetLength(I2); - const auto x = wei_k_c_y_x_desc.GetLength(I3); - const auto ho = out_n_k_ho_wo_desc.GetLength(I2); - const auto wo = out_n_k_ho_wo_desc.GetLength(I3); - - const auto M = k; - const auto N = n * ho * wo; - const auto K = c * y * x; - const auto K0 = K / tunable->K1; - - const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock); - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - - // these buffers are usually provided by the user application - DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - // these are workspace buffers that should be expressed to the user by the corresponding - // workspace API - DeviceMem workspace_buf(4096); - - void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); - void* b_k_n0_n1_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); - void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 
2048); - void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); - - const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; - - std::string program_name = - "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp"; - std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nchw"; - - std::string param = " -std=c++17 "; - std::string network_config; - - param += get_definition_string_from_types() + " " + " -DCK_USE_AMD_XDLOPS" + - get_definition_string_from_tunable(tunable); - - network_config = get_network_config_string_from_types() + "_" + - get_network_config_string_from_tunable(tunable); - - std::vector kernel1_times; - std::vector kernel2_times; - - for(index_t i = 0; i < nrepeat; ++i) - { - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = - "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare"; - auto network_config_1 = network_config + "_1"; - - timer1.Start(); - handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( - static_cast(in_n_c_hi_wi_lengths[I0]), - static_cast(in_n_c_hi_wi_lengths[I1]), - static_cast(in_n_c_hi_wi_lengths[I2]), - static_cast(in_n_c_hi_wi_lengths[I3]), - static_cast(wei_k_c_y_x_lengths[I0]), - static_cast(wei_k_c_y_x_lengths[I2]), - static_cast(wei_k_c_y_x_lengths[I3]), - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - a_k_m0_m1_grid_desc_dev_buf, - b_k_n0_n1_grid_desc_dev_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf, - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); - timer1.End(); - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw"; - auto 
network_config_2 = network_config + "_2"; - - timer2.Start(); - handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( - reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), - reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), - reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), - (const void*)(a_k_m0_m1_grid_desc_dev_buf), - (const void*)(b_k_n0_n1_grid_desc_dev_buf), - (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf), - (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); - timer2.End(); - - kernel1_times.push_back(timer1.GetElapsedTime()); - kernel2_times.push_back(timer2.GetElapsedTime()); - } - - { - auto ave_time1 = - std::accumulate( - std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / - (nrepeat - 1); - auto ave_time2 = - std::accumulate( - std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / - (nrepeat - 1); - - const auto N = in_n_c_hi_wi_lengths[I0]; - const auto C = in_n_c_hi_wi_lengths[I1]; - - const auto K = out_n_k_ho_wo_lengths[I1]; - const auto Ho = out_n_k_ho_wo_lengths[I2]; - const auto Wo = out_n_k_ho_wo_lengths[I3]; - - const auto Y = wei_k_c_y_x_lengths[I2]; - const auto X = wei_k_c_y_x_lengths[I3]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); - - std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " - << ave_time2 << "), " << perf << " TFlop/s" << std::endl; - }; - - // copy result back to host - out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 57724c7612..0000000000 --- 
a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,389 +0,0 @@ -#include "device.hpp" -#include "host_tensor.hpp" -#include "handle.hpp" -#include "online_driver_common.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" -#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" - -namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk { - -template -static std::string get_network_config_string_from_types() -{ - using namespace ck; - - std::string out; - - out += std::to_string(get_datatype_enum_from_type::value) + "_" + - std::to_string(get_datatype_enum_from_type::value) + "_" + - std::to_string(get_datatype_enum_from_type::value); - - return (out); -}; - -static std::string -get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt) -{ - std::string out("TUN_"); - - out += std::to_string(pt->BlockSize) + "_"; - - out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" + - std::to_string(pt->KPerBlock) + "_"; - out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" + - std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" + - std::to_string(pt->K1) + "_"; - - out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_"; - - out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_"; - - out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + - 
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_"; - out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_"; - out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + 
- std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; - out += std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -template -static std::string get_definition_string_from_types() -{ - using namespace ck; - - std::string out; - - out += - " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + - " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value) + - " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type::value); - - return (out); -}; - -static std::string -get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt) -{ - std::string out; - - out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); - - out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) + - " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) + - " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); - out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) + - " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) + - " -DCK_PARAM_K1=" + std::to_string(pt->K1) + - " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) + - " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat); - - out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + 
"," + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]); - - out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + - std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]); - - out += - " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); - out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + - std::to_string(pt->ABlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" + - std::to_string(pt->ABlockTransferDstScalarPerVector_K1); - out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]); - 
- out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + - std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]); - - out += - " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); - out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + - std::to_string(pt->BBlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" + - std::to_string(pt->BBlockTransferDstScalarPerVector_K1); - out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]); - - out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + - std::to_string(pt->CThreadTransferSrcDstVectorDim); - out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + - std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk - -template -void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( - online_compile::Handle* handle, - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& 
in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable, - ck::index_t nrepeat) -{ - using namespace ck; - using namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk; - using size_t = std::size_t; - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////// - // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, - // hasDoubleTailKBlockLoop - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); - - const auto n = in_n_hi_wi_c_desc.GetLength(I0); - const auto hi = in_n_hi_wi_c_desc.GetLength(I1); - const auto wi = in_n_hi_wi_c_desc.GetLength(I2); - const auto c = in_n_hi_wi_c_desc.GetLength(I3); - - const auto k = wei_k_y_x_c_desc.GetLength(I0); - const auto y = wei_k_y_x_c_desc.GetLength(I1); - const auto x = wei_k_y_x_c_desc.GetLength(I2); - - const auto ho = out_n_ho_wo_k_desc.GetLength(I1); - const auto wo = out_n_ho_wo_k_desc.GetLength(I2); - - const auto M = k; - const auto N = n * ho * wo; - const auto K = c * y * x; - const auto K0 = K / tunable->K1; - - const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock); - - // these buffers are usually provided by the user application - DeviceMem in_n_hi_wi_c_dev_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_dev_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_dev_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - 
in_n_hi_wi_c_dev_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_dev_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_dev_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - // these are workspace buffers that should be expressed to the user by the corresponding - // workspace API - DeviceMem workspace_buf(4096); - - void* a_k0_m_k1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); - void* b_k0_n_k1_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); - void* c_m0_m1_m2_n_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); - void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); - - const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; - - std::string program_name = - "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp"; - std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nhwc"; - - std::string param = " -std=c++17 "; - std::string network_config; - - param += get_definition_string_from_types() + " -DCK_USE_AMD_XDLOPS "; - param += get_definition_string_from_tunable(tunable); - - network_config = get_network_config_string_from_types() + "_" + - get_network_config_string_from_tunable(tunable); - - std::vector kernel1_times; - std::vector kernel2_times; - - for(index_t i = 0; i < nrepeat; ++i) - { - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = - "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare"; - auto network_config_1 = network_config + "_1"; - - timer1.Start(); - handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( - static_cast(in_n_hi_wi_c_lengths[I0]), - static_cast(in_n_hi_wi_c_lengths[I1]), - 
static_cast(in_n_hi_wi_c_lengths[I2]), - static_cast(in_n_hi_wi_c_lengths[I3]), - static_cast(wei_k_y_x_c_lengths[I0]), - static_cast(wei_k_y_x_c_lengths[I1]), - static_cast(wei_k_y_x_c_lengths[I2]), - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - a_k0_m_k1_grid_desc_dev_buf, - b_k0_n_k1_grid_desc_dev_buf, - c_m0_m1_m2_n_grid_desc_dev_buf, - c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf); - timer1.End(); - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk"; - auto network_config_2 = network_config + "_2"; - - timer2.Start(); - handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( - reinterpret_cast(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()), - reinterpret_cast(wei_k_y_x_c_dev_buf.GetDeviceBuffer()), - reinterpret_cast(out_n_ho_wo_k_dev_buf.GetDeviceBuffer()), - (const void*)(a_k0_m_k1_grid_desc_dev_buf), - (const void*)(b_k0_n_k1_grid_desc_dev_buf), - (const void*)(c_m0_m1_m2_n_grid_desc_dev_buf), - (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf)); - timer2.End(); - - kernel1_times.push_back(timer1.GetElapsedTime()); - kernel2_times.push_back(timer2.GetElapsedTime()); - } - - { - auto ave_time1 = - std::accumulate( - std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / - (nrepeat - 1); - auto ave_time2 = - std::accumulate( - std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / - (nrepeat - 1); - - const auto N = in_n_hi_wi_c_lengths[I0]; - const auto C = in_n_hi_wi_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - const auto K = out_n_ho_wo_k_lengths[I3]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time2; 
- - std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " - << ave_time2 << "), " << perf << " TFlop/s" << std::endl; - }; - - // copy result back to host - out_n_ho_wo_k_dev_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 92467a7668..0000000000 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,182 +0,0 @@ -#pragma once -#include "device.hpp" -#include "host_tensor.hpp" -#include "handle.hpp" -#include "online_driver_common.hpp" -#include "convolution_problem_descriptor.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" -#include "conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp" - -template -void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( - online_compile::Handle* handle, - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param, - ck::index_t nrepeat) -{ - using namespace ck; - using namespace ck_driver; - using size_t = std::size_t; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - ConvolutionProblemDescriptor 
conv_problem_desc{in_n_c_hi_wi_lengths[I0], - out_n_k_ho_wo_lengths[I1], - in_n_c_hi_wi_lengths[I1], - wei_k_c_y_x_lengths[I2], - wei_k_c_y_x_lengths[I3], - in_n_c_hi_wi_lengths[I2], - in_n_c_hi_wi_lengths[I3], - out_n_k_ho_wo_lengths[I2], - out_n_k_ho_wo_lengths[I3], - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value, - get_datatype_enum_from_type::value}; - - if(!ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::IsValidCompileParameter(conv_problem_desc, - compile_param)) - { - throw std::runtime_error("wrong! IsValidCompileParameter fail"); - } - - DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - // workspace is used for save transformed tensor descritpors created by prepare kernel - DeviceMem workspace_dev_buf( - ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetWorkSpaceSize(conv_problem_desc, compile_param)); - - const auto block_size = std::size_t( - ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetBlockSize(conv_problem_desc, compile_param)); - - const auto grid_size = std::size_t( - ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetGridSize(conv_problem_desc, compile_param)); - - const std::vector vld1 = {1, 1, 1}; - const std::vector vgd1 = {1, 1, 1}; - - const std::vector vld2 = {static_cast(block_size), 1, 1}; - const std::vector vgd2 = {static_cast(grid_size * block_size), 1, 1}; - - std::string program_name = - "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp"; - std::string algo_name = 
"implicit_gemm_conv_fwd_v6r1_dlops_nchw"; - - std::string compile_param_string = get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString(); - std::string network_config = compile_param_string; - - std::vector kernel1_times; - std::vector kernel2_times; - - for(index_t i = 0; i < nrepeat + 1; ++i) - { - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare"; - auto network_config_1 = network_config + "_1"; - - timer1.Start(); - handle->AddKernel(algo_name, - network_config_1, - program_name, - kernel_name, - vld1, - vgd1, - compile_param_string)(static_cast(in_n_c_hi_wi_lengths[I0]), - static_cast(in_n_c_hi_wi_lengths[I1]), - static_cast(in_n_c_hi_wi_lengths[I2]), - static_cast(in_n_c_hi_wi_lengths[I3]), - static_cast(wei_k_c_y_x_lengths[I0]), - static_cast(wei_k_c_y_x_lengths[I2]), - static_cast(wei_k_c_y_x_lengths[I3]), - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - (void*)(workspace_dev_buf.GetDeviceBuffer())); - timer1.End(); - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw"; - auto network_config_2 = network_config + "_2"; - - timer2.Start(); - handle->AddKernel(algo_name, - network_config_2, - program_name, - kernel_name, - vld2, - vgd2, - compile_param_string)( - reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), - reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), - reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), - (const void*)(workspace_dev_buf.GetDeviceBuffer())); - timer2.End(); - - kernel1_times.push_back(timer1.GetElapsedTime()); - kernel2_times.push_back(timer2.GetElapsedTime()); - } - - { - auto ave_time1 = - std::accumulate( - std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus{}) / - nrepeat; - auto ave_time2 = - std::accumulate( - 
std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus{}) / - nrepeat; - - float perf = (float)(conv_problem_desc.CalculateFlop()) / - (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); - - std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " - << ave_time2 << "), " << perf << " TFlop/s" << std::endl; - }; - - // copy result back to host - out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/host_tensor/CMakeLists.txt b/host/host_tensor/CMakeLists.txt index 9c30275220..3dcecf64e1 100644 --- a/host/host_tensor/CMakeLists.txt +++ b/host/host_tensor/CMakeLists.txt @@ -10,6 +10,8 @@ set(HOST_TENSOR_SOURCE ## the library target add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) +target_include_directories(host_tensor SYSTEM PUBLIC $) + target_link_libraries(host_tensor PRIVATE hip::device) target_link_libraries(host_tensor INTERFACE hip::host) diff --git a/host/host_tensor/include/conv_common.hpp b/host/host_tensor/include/conv_common.hpp index 73126b3c79..4bf2c23494 100644 --- a/host/host_tensor/include/conv_common.hpp +++ b/host/host_tensor/include/conv_common.hpp @@ -1,7 +1,7 @@ #ifndef CONV_COMMON_HPP #define CONV_COMMON_HPP -#include "dynamic_tensor_descriptor.hpp" +#include "tensor_descriptor.hpp" enum ConvTensorLayout { @@ -19,8 +19,8 @@ template constexpr auto get_convolution_output_default_4d_tensor_descriptor( - const ck::DynamicTensorDescriptor& in_desc, - const ck::DynamicTensorDescriptor& wei_desc, + const ck::TensorDescriptor& in_desc, + const ck::TensorDescriptor& wei_desc, const ConvStrides& conv_strides, const ConvDilations conv_dilations, const LeftPads& left_pads, @@ -57,12 +57,12 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor( const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; - return 
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); + return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); } template constexpr std::size_t -calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc) +calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc) { using namespace ck; diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index 2299e14921..e2cba94100 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -34,24 +34,16 @@ struct KernelTimer using device_stream_t = hipStream_t; template -void launch_kernel(F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - hipStream_t stream_id, - Args... args) +void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { + hipStream_t stream_id = nullptr; + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); } template -float launch_and_time_kernel(F kernel, - int nrepeat, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - hipStream_t stream_id, - Args... args) +float launch_and_time_kernel( + F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... 
args) { KernelTimer timer; @@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel, printf("Warm up\n"); + hipStream_t stream_id = nullptr; + // warm up hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); diff --git a/host/host_tensor/include/host_conv.hpp b/host/host_tensor/include/host_conv.hpp index 7f26cb42f7..c1228f4832 100644 --- a/host/host_tensor/include/host_conv.hpp +++ b/host/host_tensor/include/host_conv.hpp @@ -14,15 +14,13 @@ void host_direct_convolution(const Tensor& in, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, + const InRightPads&, const ConvTensorLayout layout = ConvTensorLayout::NCHW) { using namespace ck; constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { double v = 0; @@ -68,23 +66,25 @@ void host_direct_convolution(const Tensor& in, out(n, ho, wo, k) = v; }; - switch(layout) + if(layout == ConvTensorLayout::NCHW) { - case ConvTensorLayout::NCHW: make_ParallelTensorFunctor(f_nchw, out.mDesc.GetLengths()[0], out.mDesc.GetLengths()[1], out.mDesc.GetLengths()[2], out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - break; - case ConvTensorLayout::NHWC: + } + else if(layout == ConvTensorLayout::NHWC) + { make_ParallelTensorFunctor(f_nhwc, out.mDesc.GetLengths()[0], out.mDesc.GetLengths()[1], out.mDesc.GetLengths()[2], out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - break; - default: throw std::runtime_error("wrong! not supported layout"); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); } } @@ -100,17 +100,15 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, constexpr std::size_t HoPerTile = 2; constexpr std::size_t WoPerTile = 2; - std::size_t N = in_nchw.mDesc.GetLengths()[0]; - std::size_t C = in_nchw.mDesc.GetLengths()[1]; - std::size_t HI = in_nchw.mDesc.GetLengths()[2]; - std::size_t WI = in_nchw.mDesc.GetLengths()[3]; + std::size_t N = in_nchw.mDesc.GetLengths()[0]; + std::size_t C = in_nchw.mDesc.GetLengths()[1]; std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; - std::size_t HO = out_nkhw.mDesc.GetLengths()[2]; - std::size_t WO = out_nkhw.mDesc.GetLengths()[3]; + std::size_t Ho = out_nkhw.mDesc.GetLengths()[2]; + std::size_t Wo = out_nkhw.mDesc.GetLengths()[3]; index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); @@ -118,8 +116,8 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, std::size_t HiPerTile = HoPerTile + Y - 1; std::size_t WiPerTile = WoPerTile + X - 1; - std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile; - std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile; + std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile; + std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile; Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); diff --git a/host/host_tensor/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp index 07617c3926..ca23422e23 100644 --- a/host/host_tensor/include/host_conv_bwd_data.hpp +++ b/host/host_tensor/include/host_conv_bwd_data.hpp @@ -14,7 +14,7 @@ void host_direct_convolution_backward_data(Tensor& in, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, + const InRightPads& /* in_right_pads */, const ConvTensorLayout layout = 
ConvTensorLayout::NCHW) { using namespace ck; @@ -25,11 +25,6 @@ void host_direct_convolution_backward_data(Tensor& in, constexpr auto I3 = Number<3>{}; auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { - std::size_t N = in.mDesc.GetLengths()[I0]; - std::size_t C = in.mDesc.GetLengths()[I1]; - std::size_t Hi = in.mDesc.GetLengths()[I2]; - std::size_t Wi = in.mDesc.GetLengths()[I3]; - std::size_t K = wei.mDesc.GetLengths()[I0]; std::size_t Y = wei.mDesc.GetLengths()[I2]; std::size_t X = wei.mDesc.GetLengths()[I3]; @@ -74,11 +69,6 @@ void host_direct_convolution_backward_data(Tensor& in, }; auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { - std::size_t N = in.mDesc.GetLengths()[I0]; - std::size_t Hi = in.mDesc.GetLengths()[I1]; - std::size_t Wi = in.mDesc.GetLengths()[I2]; - std::size_t C = in.mDesc.GetLengths()[I3]; - std::size_t K = wei.mDesc.GetLengths()[I0]; std::size_t Y = wei.mDesc.GetLengths()[I1]; std::size_t X = wei.mDesc.GetLengths()[I2]; @@ -122,22 +112,24 @@ void host_direct_convolution_backward_data(Tensor& in, in(n, hi, wi, c) = v; }; - switch(layout) + if(layout == ConvTensorLayout::NCHW) { - case ConvTensorLayout::NCHW: make_ParallelTensorFunctor(f_nchw, in.mDesc.GetLengths()[0], in.mDesc.GetLengths()[1], in.mDesc.GetLengths()[2], in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - break; - case ConvTensorLayout::NHWC: + } + else if(layout == ConvTensorLayout::NHWC) + { make_ParallelTensorFunctor(f_nhwc, in.mDesc.GetLengths()[0], in.mDesc.GetLengths()[1], in.mDesc.GetLengths()[2], in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - break; - default: throw std::runtime_error("wrong! not supported layout"); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); } } diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp index 70778a4a94..06aed0a0c1 100644 --- a/host/host_tensor/include/host_tensor.hpp +++ b/host/host_tensor/include/host_tensor.hpp @@ -34,7 +34,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) first = false; else os << delim; - os << T{v}; + os << static_cast(v); } return os; } diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index 98192e066f..7c09843d01 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -9,7 +9,7 @@ struct GeneratorTensor_1 int value = 1; template - float operator()(Is... is) + float operator()(Is...) { return value; } diff --git a/host/host_tensor/src/device.cpp b/host/host_tensor/src/device.cpp index d0d74a4c2a..0d1b3d6883 100644 --- a/host/host_tensor/src/device.cpp +++ b/host/host_tensor/src/device.cpp @@ -24,32 +24,32 @@ struct KernelTimerImpl { KernelTimerImpl() { - hipEventCreate(&mStart); - hipEventCreate(&mEnd); + hipGetErrorString(hipEventCreate(&mStart)); + hipGetErrorString(hipEventCreate(&mEnd)); } ~KernelTimerImpl() { - hipEventDestroy(mStart); - hipEventDestroy(mEnd); + hipGetErrorString(hipEventDestroy(mStart)); + hipGetErrorString(hipEventDestroy(mEnd)); } void Start() { - hipDeviceSynchronize(); - hipEventRecord(mStart, 0); + hipGetErrorString(hipDeviceSynchronize()); + hipGetErrorString(hipEventRecord(mStart, nullptr)); } void End() { - hipEventRecord(mEnd, 0); - hipEventSynchronize(mEnd); + hipGetErrorString(hipEventRecord(mEnd, nullptr)); + hipGetErrorString(hipEventSynchronize(mEnd)); } float GetElapsedTime() const { float time; - hipEventElapsedTime(&time, mStart, mEnd); + hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd)); return time; } diff --git a/host/online_compile/CMakeLists.txt 
b/host/online_compile/CMakeLists.txt deleted file mode 100644 index 1b66703fcd..0000000000 --- a/host/online_compile/CMakeLists.txt +++ /dev/null @@ -1,168 +0,0 @@ -set(CMAKE_CXX_COMPILER /opt/rocm/llvm/bin/clang++) - -## for online-compiling of HIP kernels -set(OLC_HIP_COMPILER ${CMAKE_CXX_COMPILER} CACHE PATH "") - -## reset to avoid the C++ options from the parent project -set(CMAKE_CXX_FLAGS "") -message("Compiling options for library and kernels: ${CMAKE_CXX_FLAGS}") - -# look for and register clang-offload-bundler -if(OLC_HIP_COMPILER MATCHES ".*clang\\+\\+$") - find_program(OLC_OFFLOADBUNDLER_BIN clang-offload-bundler - PATH_SUFFIXES bin - PATHS - /opt/rocm/llvm - ${CMAKE_INSTALL_PREFIX}/llvm - ) -endif() - -if(OLC_OFFLOADBUNDLER_BIN) - message(STATUS "clang-offload-bundler found: ${OLC_OFFLOADBUNDLER_BIN}") - set(OLC_OFFLOADBUNDLER_BIN "${OLC_OFFLOADBUNDLER_BIN}") -else() - # look for and register extractkernel - message(STATUS "clang-offload-bundler not found") - - find_program(EXTRACTKERNEL_BIN extractkernel - PATH_SUFFIXES bin - PATHS - /opt/rocm/hip - /opt/rocm/hcc - /opt/rocm - ${CMAKE_INSTALL_PREFIX}/hip - ${CMAKE_INSTALL_PREFIX}/hcc - ${CMAKE_INSTALL_PREFIX} - - ) - if(EXTRACTKERNEL_BIN) - message(STATUS "extractkernel found: ${EXTRACTKERNEL_BIN}") - set(EXTRACTKERNEL_BIN "${EXTRACTKERNEL_BIN}") - else() - message(FATAL_ERROR "extractkernel not found") - endif() -endif() - -option(Boost_USE_STATIC_LIBS "Use boost static libraries" OFF) -set(BOOST_COMPONENTS filesystem) -add_definitions(-DBOOST_ALL_NO_LIB=1) -find_package(Boost REQUIRED COMPONENTS ${BOOST_COMPONENTS}) - -# HIP is always required -find_package(hip REQUIRED PATHS /opt/rocm) -message(STATUS "Build with HIP ${hip_VERSION}") -target_flags(HIP_COMPILER_FLAGS hip::device) -# Remove cuda arch flags -string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") -string(REGEX REPLACE --offload-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}") - 
-set(OLC_hip_VERSION_MAJOR "${hip_VERSION_MAJOR}") -set(OLC_hip_VERSION_MINOR "${hip_VERSION_MINOR}") -set(OLC_hip_VERSION_PATCH "${hip_VERSION_PATCH}") - -option(ENABLE_DEBUG "Build to enable debugging" ON) -if(ENABLE_DEBUG) - set(OLC_DEBUG 1) -else() - set(OLC_DEBUG 0) -endif() - -configure_file("${PROJECT_SOURCE_DIR}/host/online_compile/include/config.h.in" "${PROJECT_BINARY_DIR}/host/online_compile/include/config.h") - -include_directories(BEFORE - ${PROJECT_BINARY_DIR}/host/online_compile/include -) - -message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") - -## HIP_COMPILER_FLAGS will be used for on-line compiling of the HIP kernels -set(HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS} ${HIP_ONLINE_COMPILER_FLAGS}") -add_definitions("-DHIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}") - -file(GLOB_RECURSE COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/*/*.hpp") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp") -set(MCONV_KERNEL_INCLUDES - ${COMPOSABLE_KERNEL_INCLUDE_1} - ${COMPOSABLE_KERNEL_INCLUDE_2} - ) - -file(GLOB_RECURSE MCONV_KERNELS "${PROJECT_SOURCE_DIR}/composable_kernel/src/kernel_wrapper/*.cpp") - -add_kernels(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNELS}") -add_kernel_includes(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNEL_INCLUDES}") - -set(ONLINE_COMPILATION_SOURCE - ${PROJECT_BINARY_DIR}/kernel.cpp - ${PROJECT_BINARY_DIR}/kernel_includes.cpp -) - -include_directories(BEFORE - ${PROJECT_BINARY_DIR}/host/online_compile/include - include -) - -set(OLC_HIP_UTILITY_CPPS - hip_utility/logger.cpp - hip_utility/tmp_dir.cpp - hip_utility/md5.cpp - hip_utility/exec_utils.cpp - hip_utility/target_properties.cpp - hip_utility/handlehip.cpp - hip_utility/kernel_build_params.cpp - hip_utility/hip_build_utils.cpp - hip_utility/hipoc_program.cpp - hip_utility/hipoc_kernel.cpp - hip_utility/kernel_cache.cpp - hip_utility/binary_cache.cpp - ) - -list(APPEND OLC_SOURCES ${OLC_HIP_UTILITY_CPPS} 
${OLC_HIP_UTILITY_HEADERS}) - -## addkernels provide the tool to create inlined kernels in one header -add_subdirectory(addkernels) - -function(inline_kernels_src KERNELS KERNEL_INCLUDES) - set(KERNEL_SRC_HPP_FILENAME batch_all.cpp.hpp) - set(KERNEL_SRC_HPP_PATH ${PROJECT_BINARY_DIR}/inlined_kernels/${KERNEL_SRC_HPP_FILENAME}) - set(KERNEL_SRC_CPP_PATH ${PROJECT_BINARY_DIR}/inlined_kernels/batch_all.cpp) - - add_custom_command( - OUTPUT ${KERNEL_SRC_HPP_PATH} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - DEPENDS addkernels ${KERNELS} ${KERNEL_INCLUDES} - COMMAND $ -target ${KERNEL_SRC_HPP_PATH} -extern -source ${KERNELS} - COMMENT "Inlining All kernels" - ) - configure_file(kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH}) - list(APPEND OLC_SOURCES ${KERNEL_SRC_CPP_PATH} ${KERNEL_SRC_HPP_PATH}) - - set(OLC_SOURCES ${OLC_SOURCES} PARENT_SCOPE) -endfunction() - -inline_kernels_src("${MCONV_KERNELS}" "${MCONV_KERNEL_INCLUDES}") - -list(APPEND ONLINE_COMPILATION_SOURCE ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h) - -add_custom_command( - OUTPUT ${PROJECT_BINARY_DIR}/olc_kernel_includes.h - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - DEPENDS addkernels ${MCONV_KERNEL_INCLUDES} - COMMAND $ -no-recurse -guard GUARD_OLC_KERNEL_INCLUDES_HPP_ -target ${PROJECT_BINARY_DIR}/olc_kernel_includes.h -source ${MCONV_KERNEL_INCLUDES} - COMMENT "Inlining HIP kernel includes" - ) - -## the library target -add_library(online_compile SHARED ${ONLINE_COMPILATION_SOURCE}) - -target_include_directories(online_compile PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/online_compile/include/) -target_include_directories(online_compile PRIVATE ${PROJECT_BINARY_DIR}) -target_include_directories(online_compile PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/) - -target_link_libraries(online_compile PRIVATE hip::device) -target_link_libraries(online_compile INTERFACE hip::host) -target_link_libraries(online_compile PRIVATE Boost::filesystem) - -target_compile_features(online_compile 
PUBLIC) -set_target_properties(online_compile PROPERTIES POSITION_INDEPENDENT_CODE ON) - -install(TARGETS online_compile LIBRARY DESTINATION lib) diff --git a/host/online_compile/addkernels/addkernels.cpp b/host/online_compile/addkernels/addkernels.cpp deleted file mode 100644 index 5be523d97b..0000000000 --- a/host/online_compile/addkernels/addkernels.cpp +++ /dev/null @@ -1,264 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#include "include_inliner.hpp" -#include -#include -#include -#include -#include -#include -#include - -void Bin2Hex(std::istream& source, - std::ostream& target, - const std::string& variable, - bool nullTerminate, - size_t bufferSize, - size_t lineSize) -{ - source.seekg(0, std::ios::end); - std::unique_ptr buffer(new unsigned char[bufferSize]); - std::streamoff sourceSize = source.tellg(); - std::streamoff blockStart = 0; - - if(variable.length() != 0) - { - target << "extern const size_t " << variable << "_SIZE;" << std::endl; - target << "extern const unsigned char " << variable << "[];" << std::endl; - target << "const size_t " << variable << "_SIZE = " << std::setbase(10) << sourceSize << ";" - << std::endl; - target << "const unsigned char " << variable << "[] = {" << std::endl; - } - - target << std::setbase(16) << std::setfill('0'); - source.seekg(0, std::ios::beg); - - while(blockStart < sourceSize) - { - source.read(reinterpret_cast(buffer.get()), bufferSize); - - std::streamoff pos = source.tellg(); - std::streamoff blockSize = (pos < 0 ? sourceSize : pos) - blockStart; - std::streamoff i = 0; - - while(i < blockSize) - { - size_t j = i; - size_t end = std::min(i + lineSize, blockSize); - - for(; j < end; j++) - target << "0x" << std::setw(2) << static_cast(buffer[j]) << ","; - - target << std::endl; - i = end; - } - - blockStart += blockSize; - } - - if(nullTerminate) - target << "0x00," << std::endl; - - if(variable.length() != 0) - { - target << "};" << std::endl; - } -} - -void PrintHelp() -{ - std::cout << "Usage: bin2hex {