commit 6fe3627a9eb35f1237266f1b6cc8fd3456aed67d Author: Chao Liu Date: Thu Aug 19 10:55:03 2021 -0500 Composable kernel init integration v3 (#1097) * Squashed 'src/composable_kernel/' content from commit f6edda611 git-subtree-dir: src/composable_kernel git-subtree-split: f6edda6119ebbb237dfa6270797b34f960d7b190 * add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files * Squashed 'src/composable_kernel/' changes from f6edda611..5781adf5c 5781adf5c Update develop (#5) (#6) 97e6d514f Merge pull request #4 from ROCmSoftwarePlatform/separate_online_compile 7b1ec41e5 refactor 49c33aaea refactor 54b3e73d1 rename git-subtree-dir: src/composable_kernel git-subtree-split: 5781adf5cf4ac753e2e36da7385791775b744bf7 * fix * refactor * remove online compilation from CK * refactor * fix * add ctest * add c-style pointer cast * vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast * fix clang warning suppression * tidy * suppress cppcheck * fix enum issue * revert chagnes to hip build * fix kernel filename * update CK build script * rename * rename * make innner product compatiable on gfx900 * Update src/include/miopen/solver/ck_utility_common.hpp Co-authored-by: JD * compiler parameter use stream * use int instead of index_t in kernel wrapper * DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element * refactor * refactor * change cmakelist * change ck common utility * fix Co-authored-by: JD diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..22f2674966 --- /dev/null +++ b/.clang-format @@ -0,0 +1,90 @@ +--- +Language: Cpp +AccessModifierOffset: 0 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: true +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true 
+AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + BeforeCatch: true + BeforeElse: true + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ] +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IndentCaseLabels: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 60 +PointerAlignment: Left +ReflowComments: true +SortIncludes: false +SpaceAfterCStyleCast: false +# SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: Never +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 
+SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... + diff --git a/.clang-tidy b/.clang-tidy new file mode 100644 index 0000000000..5c2b781687 --- /dev/null +++ b/.clang-tidy @@ -0,0 +1,3 @@ +CheckOptions: + - key: bugprone-reserved-identifier.AllowedIdentifiers + value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000..306e6ca649 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,198 @@ +cmake_minimum_required(VERSION 3.5) +project(composable_kernel) + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") + +include(CheckCXXCompilerFlag) + +## C++ +enable_language(CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") + +## OpenMP +if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") + # workaround issue hipcc in rocm3.5 cannot find openmp + set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") + set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument") + set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5") + set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) + set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES}) + set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES}) +else() + find_package(OpenMP REQUIRED) +endif() + +message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") +message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") +message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") +message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") + +set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +link_libraries(${OpenMP_gomp_LIBRARY}) +link_libraries(${OpenMP_pthread_LIBRARY}) + +## HIP +find_package(HIP REQUIRED) +message(STATUS "Build with HIP ${hip_VERSION}") + +## half +#find_path(HALF_INCLUDE_DIR half.hpp) 
+message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") + +# CMAKE_CXX_FLAGS +SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") +if(BUILD_DEV) + string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") +endif() +message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") + +## tidy +include(EnableCompilerWarnings) +set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) +if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") + set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) +# Enable tidy on hip +elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU") + set(MIOPEN_TIDY_ERRORS ALL) +endif() + +include(ClangTidy) +enable_clang_tidy( + CHECKS + * + -abseil-* + -android-cloexec-fopen + # Yea we shouldn't be using rand() + -cert-msc30-c + -bugprone-exception-escape + -bugprone-macro-parentheses + -cert-env33-c + -cert-msc32-c + -cert-msc50-cpp + -cert-msc51-cpp + -cert-dcl37-c + -cert-dcl51-cpp + -clang-analyzer-alpha.core.CastToStruct + -clang-analyzer-optin.performance.Padding + -clang-diagnostic-deprecated-declarations + -clang-diagnostic-extern-c-compat + -clang-diagnostic-unused-command-line-argument + -cppcoreguidelines-avoid-c-arrays + -cppcoreguidelines-avoid-magic-numbers + -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables + -cppcoreguidelines-macro-usage + -cppcoreguidelines-non-private-member-variables-in-classes + -cppcoreguidelines-pro-bounds-array-to-pointer-decay + -cppcoreguidelines-pro-bounds-constant-array-index + -cppcoreguidelines-pro-bounds-pointer-arithmetic + -cppcoreguidelines-pro-type-member-init + -cppcoreguidelines-pro-type-reinterpret-cast + -cppcoreguidelines-pro-type-union-access + -cppcoreguidelines-pro-type-vararg + -cppcoreguidelines-special-member-functions + -fuchsia-* + -google-explicit-constructor + -google-readability-braces-around-statements + -google-readability-todo + -google-runtime-int + -google-runtime-references 
+ -hicpp-vararg + -hicpp-braces-around-statements + -hicpp-explicit-conversions + -hicpp-named-parameter + -hicpp-no-array-decay + # We really shouldn't use bitwise operators with signed integers, but + # opencl leaves us no choice + -hicpp-avoid-c-arrays + -hicpp-signed-bitwise + -hicpp-special-member-functions + -hicpp-uppercase-literal-suffix + -hicpp-use-auto + -hicpp-use-equals-default + -hicpp-use-override + -llvm-header-guard + -llvm-include-order + #-llvmlibc-* + -llvmlibc-restrict-system-libc-headers + -llvmlibc-callee-namespace + -llvmlibc-implementation-in-namespace + -llvm-else-after-return + -llvm-qualified-auto + -misc-misplaced-const + -misc-non-private-member-variables-in-classes + -misc-no-recursion + -modernize-avoid-bind + -modernize-avoid-c-arrays + -modernize-pass-by-value + -modernize-use-auto + -modernize-use-default-member-init + -modernize-use-equals-default + -modernize-use-trailing-return-type + -modernize-use-transparent-functors + -performance-unnecessary-value-param + -readability-braces-around-statements + -readability-else-after-return + # we are not ready to use it, but very useful + -readability-function-cognitive-complexity + -readability-isolate-declaration + -readability-magic-numbers + -readability-named-parameter + -readability-uppercase-literal-suffix + -readability-convert-member-functions-to-static + -readability-qualified-auto + -readability-redundant-string-init + # too many narrowing conversions in our code + -bugprone-narrowing-conversions + -cppcoreguidelines-narrowing-conversions + -altera-struct-pack-align + -cppcoreguidelines-prefer-member-initializer + + ${MIOPEN_TIDY_CHECKS} + ${MIOPEN_TIDY_ERRORS} + HEADER_FILTER + "\.hpp$" + EXTRA_ARGS + -DMIOPEN_USE_CLANG_TIDY +) + +include(CppCheck) +enable_cppcheck( + CHECKS + warning + style + performance + portability + SUPPRESS + ConfigurationNotChecked + constStatement + duplicateCondition + noExplicitConstructor + passedByValue + preprocessorErrorDirective + 
shadowVariable + unusedFunction + unusedPrivateFunction + unusedStructMember + unmatchedSuppression + FORCE + SOURCES + host/host_tensor/src + host/driver_offline/src + composable_kernel/src/kernel_wrapper + INCLUDE + host/host_tensor/include + host/solver/include + host/driver_offline/include + composable_kernel/include/* + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + DEFINE + CPPCHECK=1 + __linux__=1 +) + +add_subdirectory(host) diff --git a/README.md b/README.md new file mode 100644 index 0000000000..4f071d5896 --- /dev/null +++ b/README.md @@ -0,0 +1,177 @@ +# How to build and run + +# Docker +``` +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.2-tf2.4-dev \ +/bin/bash +``` + +# Install Boost for online compilation +https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install + + +# Build +Add path of Boost +``` + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +``` + +``` +mkdir build && cd build +``` + +cmake cmd. Need to Specify target ID, example below is gfx908 +``` +cmake \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +.. 
+``` + +Build drivers: \ +``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \ +``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \ +``conv_fwd_driver_online`` is (online compilation) driver for forward convolution +``` + make -j conv_fwd_driver_offline + make -j conv_bwd_driver_offline + make -j conv_fwd_driver_online +``` + +# Run +* layout: 0 = NCHW; 1 = NHWC +* algo: algorithm +* verify: 0 = no verification; 1 = do verification +* init: 0 ~ 5. initialization method +* log: 0 = no log; 1 = do log +* repeat: number of times the kernel is launched +``` +######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + ./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 +``` + +# Result +Forward convolution, FP16, NCHW +``` +./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + +layout: 0 +in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1} +wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1} +out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {2, 2, } +ConvDilations size 2, {1, 1, } +device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw +a_k0_m_k1_grid_desc{216, 256, 8} +b_k0_n_k1_grid_desc{216, 165888, 8} +c_m_n_grid_desc{ 256, 165888} +launch_and_time_kernel: grid_dim {1296, 1, 1}, 
block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 1.4155 ms, 103.686 TFlop/s +``` + +Forward convolution, FP16, NCHW +``` + ./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + + layout: 0 +in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1} +wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1} +out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {1, 1, } +ConvDilations size 2, {1, 1, } +device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw +a_k0_m_k1_grid_desc{288, 1024, 8} +b_k0_n_k1_grid_desc{288, 50176, 8} +c_m_n_grid_desc{ 1024, 50176} +launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 2.21357 ms, 106.959 TFlop/s + ``` + + Forward convolution, FP16, NHWC + ``` + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + + layout: 1 +in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1} +wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1} +out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {2, 2, } +ConvDilations size 2, {1, 1, } +device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk +a_k0_m_k1_grid_desc{216, 165888, 8} +b_k0_n_k1_grid_desc{216, 256, 8} +c_m_n_grid_desc{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... 
+Average time : 1.12014 ms, 131.025 TFlop/s + ``` + + Forward convolution, FP16, NHWC + ``` + ./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1 + + layout: 1 +in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} +wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1} +out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {1, 1, } +ConvDilations size 2, {1, 1, } +device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk +a_k0_m_k1_grid_desc{288, 50176, 8} +b_k0_n_k1_grid_desc{288, 1024, 8} +c_m_n_grid_desc{ 50176, 1024} +launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... +Average time : 1.86877 ms, 126.693 TFlop/s + ``` + + Backward data convolution, FP16, NHWC + ``` + ./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1 + + layout: 1 +in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1} +wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1} +out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1} +InLeftPads size 2, {1, 1, } +InRightPads size 2, {1, 1, } +ConvStrides size 2, {1, 1, } +ConvDilations size 2, {1, 1, } +device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk +a_k0_m_k1_grid_desc{288, 50176, 8} +b_k0_n_k1_grid_desc{288, 1024, 8} +c_m_n_grid_desc{ 50176, 1024} +launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 1 times... 
+Average time : 2.22461 ms, 106.428 TFlop/s +``` diff --git a/cmake/Analyzers.cmake b/cmake/Analyzers.cmake new file mode 100644 index 0000000000..1bf1a52c68 --- /dev/null +++ b/cmake/Analyzers.cmake @@ -0,0 +1,34 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################ + +if(NOT TARGET analyze) + add_custom_target(analyze) +endif() + +function(mark_as_analyzer) + add_dependencies(analyze ${ARGN}) +endfunction() + diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake new file mode 100644 index 0000000000..01b348c458 --- /dev/null +++ b/cmake/ClangTidy.cmake @@ -0,0 +1,162 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################ +include(CMakeParseArguments) +include(Analyzers) + +get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH) + +find_program(CLANG_TIDY_EXE + NAMES + clang-tidy + clang-tidy-5.0 + clang-tidy-4.0 + clang-tidy-3.9 + clang-tidy-3.8 + clang-tidy-3.7 + clang-tidy-3.6 + clang-tidy-3.5 + HINTS + ${CLANG_TIDY_EXE_HINT} + PATH_SUFFIXES + compiler/bin + PATHS + /opt/rocm/llvm/bin + /opt/rocm/hcc + /usr/local/opt/llvm/bin +) + +function(find_clang_tidy_version VAR) + execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT) + separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}") + list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX) + if(VERSION_INDEX GREATER 0) + math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1") + list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION) + set(${VAR} ${VERSION} PARENT_SCOPE) + else() + set(${VAR} "0.0" PARENT_SCOPE) + endif() + +endfunction() + +if( NOT CLANG_TIDY_EXE ) + message( STATUS "Clang tidy not found" ) + set(CLANG_TIDY_VERSION "0.0") +else() + find_clang_tidy_version(CLANG_TIDY_VERSION) + message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}") +endif() + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits) +file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR}) +set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR}) + +macro(enable_clang_tidy) + set(options ANALYZE_TEMPORARY_DTORS ALL) + set(oneValueArgs HEADER_FILTER) + set(multiValueArgs CHECKS ERRORS EXTRA_ARGS) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}") + string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}") + set(CLANG_TIDY_EXTRA_ARGS) + foreach(ARG ${PARSE_EXTRA_ARGS}) + list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}") + endforeach() + + 
set(CLANG_TIDY_ALL) + if(PARSE_ALL) + set(CLANG_TIDY_ALL ALL) + endif() + + message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}") + + if (${PARSE_ANALYZE_TEMPORARY_DTORS}) + set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors") + endif() + + if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0") + set(CLANG_TIDY_ERRORS_ARG "") + else() + set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'") + endif() + + if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0") + set(CLANG_TIDY_QUIET_ARG "") + else() + set(CLANG_TIDY_QUIET_ARG "-quiet") + endif() + + if(PARSE_HEADER_FILTER) + string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}") + else() + set(CLANG_TIDY_HEADER_FILTER ".*") + endif() + + set(CLANG_TIDY_COMMAND + ${CLANG_TIDY_EXE} + ${CLANG_TIDY_QUIET_ARG} + -p ${CMAKE_BINARY_DIR} + -checks='${CLANG_TIDY_CHECKS}' + ${CLANG_TIDY_ERRORS_ARG} + ${CLANG_TIDY_EXTRA_ARGS} + ${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS} + -header-filter='${CLANG_TIDY_HEADER_FILTER}' + ) + add_custom_target(tidy ${CLANG_TIDY_ALL}) + mark_as_analyzer(tidy) + add_custom_target(tidy-base) + add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR}) + add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR}) + add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir) + add_dependencies(tidy-base tidy-make-fixit-dir) +endmacro() + +function(clang_tidy_check TARGET) + get_target_property(SOURCES ${TARGET} SOURCES) + # TODO: Use generator expressions instead + # COMMAND ${CLANG_TIDY_COMMAND} $ + # COMMAND ${CLANG_TIDY_COMMAND} $, > + foreach(SOURCE ${SOURCES}) + if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS")) + string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file) + set(tidy_target tidy-target-${TARGET}-${tidy_file}) + add_custom_target(${tidy_target} + # for some targets clang-tidy not able to get information from .clang-tidy + DEPENDS 
${SOURCE} + COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." + ) + add_dependencies(${tidy_target} ${TARGET}) + add_dependencies(${tidy_target} tidy-base) + add_dependencies(tidy ${tidy_target}) + endif() + endforeach() +endfunction() + diff --git a/cmake/CppCheck.cmake b/cmake/CppCheck.cmake new file mode 100644 index 0000000000..797dcf4b4d --- /dev/null +++ b/cmake/CppCheck.cmake @@ -0,0 +1,130 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################ + +include(CMakeParseArguments) +include(ProcessorCount) +include(Analyzers) + +find_program(CPPCHECK_EXE + NAMES + cppcheck + PATHS + /opt/rocm/bin +) + +ProcessorCount(CPPCHECK_JOBS) + +set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build) +file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR}) +set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR}) + +macro(enable_cppcheck) + set(options FORCE) + set(oneValueArgs) + set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}") + string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*") + file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}") + set(CPPCHECK_DEFINES) + foreach(DEF ${PARSE_DEFINE}) + set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}") + endforeach() + + set(CPPCHECK_UNDEFINES) + foreach(DEF ${PARSE_UNDEFINE}) + set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}") + endforeach() + + set(CPPCHECK_INCLUDES) + foreach(INC ${PARSE_INCLUDE}) + set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}") + endforeach() + + # set(CPPCHECK_FORCE) + set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json") + if(PARSE_FORCE) + set(CPPCHECK_FORCE --force) + endif() + + set(SOURCES) + set(GLOBS) + foreach(SOURCE ${PARSE_SOURCES}) + get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE) + if(EXISTS ${ABS_SOURCE}) + if(IS_DIRECTORY ${ABS_SOURCE}) + set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h") + else() + set(SOURCES "${SOURCES} ${ABS_SOURCE}") + endif() + else() + set(GLOBS "${GLOBS} ${ABS_SOURCE}") + endif() + endforeach() + + file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake " + file(GLOB_RECURSE GSRCS ${GLOBS}) + 
set(CPPCHECK_COMMAND + ${CPPCHECK_EXE} + -q + # -v + # --report-progress + ${CPPCHECK_FORCE} + --cppcheck-build-dir=${CPPCHECK_BUILD_DIR} + --platform=native + --template=gcc + --error-exitcode=1 + -j ${CPPCHECK_JOBS} + ${CPPCHECK_DEFINES} + ${CPPCHECK_UNDEFINES} + ${CPPCHECK_INCLUDES} + --enable=${CPPCHECK_CHECKS} + --inline-suppr + --suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions + ${SOURCES} \${GSRCS} + ) + string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\") + message(\"\${CPPCHECK_SHOW_COMMAND}\") + execute_process( + COMMAND \${CPPCHECK_COMMAND} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE RESULT + ) + if(NOT RESULT EQUAL 0) + message(FATAL_ERROR \"Cppcheck failed\") + endif() +") + + add_custom_target(cppcheck + COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "cppcheck: Running cppcheck..." + ) + mark_as_analyzer(cppcheck) +endmacro() + + diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake new file mode 100644 index 0000000000..2e3669fcdf --- /dev/null +++ b/cmake/DoxygenDoc.cmake @@ -0,0 +1,355 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +include(CMakeParseArguments) +include(MainDoc) + +find_program(DOXYGEN_EXECUTABLE NAMES doxygen + PATH_SUFFIXES bin + DOC "Doxygen documentation generator" +) +mark_as_advanced(DOXYGEN_EXECUTABLE) + +find_path(DOT_EXECUTABLE NAMES dot + PATH_SUFFIXES bin + DOC "Graphviz" +) +mark_as_advanced(DOT_EXECUTABLE) + +set(DOXYGEN_ARGS +ABBREVIATE_BRIEF +ALIASES +ALLEXTERNALS +ALLOW_UNICODE_NAMES +ALPHABETICAL_INDEX +ALWAYS_DETAILED_SEC +AUTOLINK_SUPPORT +BINARY_TOC +BRIEF_MEMBER_DESC +BUILTIN_STL_SUPPORT +CALLER_GRAPH +CALL_GRAPH +CASE_SENSE_NAMES +CHM_FILE +CHM_INDEX_ENCODING +CITE_BIB_FILES +CLANG_ASSISTED_PARSING +CLANG_OPTIONS +CLASS_DIAGRAMS +CLASS_GRAPH +COLLABORATION_GRAPH +COLS_IN_ALPHA_INDEX +COMPACT_LATEX +COMPACT_RTF +CPP_CLI_SUPPORT +CREATE_SUBDIRS +DIAFILE_DIRS +DIA_PATH +DIRECTORY_GRAPH +DISABLE_INDEX +DISTRIBUTE_GROUP_DOC +DOCBOOK_OUTPUT +DOCBOOK_PROGRAMLISTING +DOCSET_BUNDLE_ID +DOCSET_FEEDNAME +DOCSET_PUBLISHER_ID +DOCSET_PUBLISHER_NAME +DOTFILE_DIRS +DOT_CLEANUP +DOT_FONTNAME +DOT_FONTPATH +DOT_FONTSIZE +DOT_GRAPH_MAX_NODES +DOT_IMAGE_FORMAT +DOT_MULTI_TARGETS +DOT_NUM_THREADS +# DOT_PATH +DOT_TRANSPARENT +DOXYFILE_ENCODING +ECLIPSE_DOC_ID +ENABLED_SECTIONS +ENABLE_PREPROCESSING +ENUM_VALUES_PER_LINE +EXAMPLE_PATH +EXAMPLE_PATTERNS +EXAMPLE_RECURSIVE +EXCLUDE +EXCLUDE_PATTERNS +EXCLUDE_SYMBOLS +EXCLUDE_SYMLINKS +EXPAND_AS_DEFINED +EXPAND_ONLY_PREDEF +EXTENSION_MAPPING +EXTERNAL_GROUPS 
+EXTERNAL_PAGES +EXTERNAL_SEARCH +EXTERNAL_SEARCH_ID +EXTRACT_ALL +EXTRACT_ANON_NSPACES +EXTRACT_LOCAL_CLASSES +EXTRACT_LOCAL_METHODS +EXTRACT_PACKAGE +EXTRACT_PRIVATE +EXTRACT_STATIC +EXTRA_PACKAGES +EXTRA_SEARCH_MAPPINGS +EXT_LINKS_IN_WINDOW +FILE_PATTERNS +FILE_VERSION_FILTER +FILTER_PATTERNS +FILTER_SOURCE_FILES +FILTER_SOURCE_PATTERNS +FORCE_LOCAL_INCLUDES +FORMULA_FONTSIZE +FORMULA_TRANSPARENT +FULL_PATH_NAMES +GENERATE_AUTOGEN_DEF +GENERATE_BUGLIST +GENERATE_CHI +GENERATE_DEPRECATEDLIST +GENERATE_DOCBOOK +GENERATE_DOCSET +GENERATE_ECLIPSEHELP +GENERATE_HTML +GENERATE_HTMLHELP +GENERATE_LATEX +GENERATE_LEGEND +GENERATE_MAN +GENERATE_PERLMOD +GENERATE_QHP +GENERATE_RTF +GENERATE_TAGFILE +GENERATE_TESTLIST +GENERATE_TODOLIST +GENERATE_TREEVIEW +GENERATE_XML +GRAPHICAL_HIERARCHY +GROUP_GRAPHS +GROUP_NESTED_COMPOUNDS +# HAVE_DOT +HHC_LOCATION +HIDE_COMPOUND_REFERENCE +HIDE_FRIEND_COMPOUNDS +HIDE_IN_BODY_DOCS +HIDE_SCOPE_NAMES +HIDE_UNDOC_CLASSES +HIDE_UNDOC_MEMBERS +HIDE_UNDOC_RELATIONS +HTML_COLORSTYLE_GAMMA +HTML_COLORSTYLE_HUE +HTML_COLORSTYLE_SAT +HTML_DYNAMIC_SECTIONS +HTML_EXTRA_FILES +HTML_EXTRA_STYLESHEET +HTML_FILE_EXTENSION +HTML_FOOTER +HTML_HEADER +HTML_INDEX_NUM_ENTRIES +HTML_OUTPUT +HTML_STYLESHEET +HTML_TIMESTAMP +IDL_PROPERTY_SUPPORT +IGNORE_PREFIX +IMAGE_PATH +INCLUDED_BY_GRAPH +INCLUDE_FILE_PATTERNS +INCLUDE_GRAPH +INCLUDE_PATH +INHERIT_DOCS +INLINE_GROUPED_CLASSES +INLINE_INFO +INLINE_INHERITED_MEMB +INLINE_SIMPLE_STRUCTS +INLINE_SOURCES +INPUT +INPUT_ENCODING +INPUT_FILTER +INTERACTIVE_SVG +INTERNAL_DOCS +JAVADOC_AUTOBRIEF +LATEX_BATCHMODE +LATEX_BIB_STYLE +LATEX_CMD_NAME +LATEX_EXTRA_FILES +LATEX_EXTRA_STYLESHEET +LATEX_FOOTER +LATEX_HEADER +LATEX_HIDE_INDICES +LATEX_OUTPUT +LATEX_SOURCE_CODE +LATEX_TIMESTAMP +LAYOUT_FILE +LOOKUP_CACHE_SIZE +MACRO_EXPANSION +MAKEINDEX_CMD_NAME +MAN_EXTENSION +MAN_LINKS +MAN_OUTPUT +MAN_SUBDIR +MARKDOWN_SUPPORT +MATHJAX_CODEFILE +MATHJAX_EXTENSIONS +MATHJAX_FORMAT +MATHJAX_RELPATH +MAX_DOT_GRAPH_DEPTH 
+MAX_INITIALIZER_LINES +MSCFILE_DIRS +MSCGEN_PATH +MULTILINE_CPP_IS_BRIEF +OPTIMIZE_FOR_FORTRAN +OPTIMIZE_OUTPUT_FOR_C +OPTIMIZE_OUTPUT_JAVA +OPTIMIZE_OUTPUT_VHDL +OUTPUT_DIRECTORY +OUTPUT_LANGUAGE +PAPER_TYPE +PDF_HYPERLINKS +PERLMOD_LATEX +PERLMOD_MAKEVAR_PREFIX +PERLMOD_PRETTY +PERL_PATH +PLANTUML_CFG_FILE +PLANTUML_INCLUDE_PATH +PLANTUML_JAR_PATH +PREDEFINED +PROJECT_BRIEF +PROJECT_LOGO +PROJECT_NAME +PROJECT_NUMBER +QCH_FILE +QHG_LOCATION +QHP_CUST_FILTER_ATTRS +QHP_CUST_FILTER_NAME +QHP_NAMESPACE +QHP_SECT_FILTER_ATTRS +QHP_VIRTUAL_FOLDER +QT_AUTOBRIEF +QUIET +RECURSIVE +REFERENCED_BY_RELATION +REFERENCES_LINK_SOURCE +REFERENCES_RELATION +REPEAT_BRIEF +RTF_EXTENSIONS_FILE +RTF_HYPERLINKS +RTF_OUTPUT +RTF_SOURCE_CODE +RTF_STYLESHEET_FILE +SEARCHDATA_FILE +SEARCHENGINE +SEARCHENGINE_URL +SEARCH_INCLUDES +SEPARATE_MEMBER_PAGES +SERVER_BASED_SEARCH +SHORT_NAMES +SHOW_FILES +SHOW_GROUPED_MEMB_INC +SHOW_INCLUDE_FILES +SHOW_NAMESPACES +SHOW_USED_FILES +SIP_SUPPORT +SKIP_FUNCTION_MACROS +SORT_BRIEF_DOCS +SORT_BY_SCOPE_NAME +SORT_GROUP_NAMES +SORT_MEMBERS_CTORS_1ST +SORT_MEMBER_DOCS +SOURCE_BROWSER +SOURCE_TOOLTIPS +STRICT_PROTO_MATCHING +STRIP_CODE_COMMENTS +STRIP_FROM_INC_PATH +STRIP_FROM_PATH +SUBGROUPING +TAB_SIZE +TAGFILES +TCL_SUBST +TEMPLATE_RELATIONS +TOC_EXPAND +TOC_INCLUDE_HEADINGS +TREEVIEW_WIDTH +TYPEDEF_HIDES_STRUCT +UML_LIMIT_NUM_FIELDS +UML_LOOK +USE_HTAGS +USE_MATHJAX +USE_MDFILE_AS_MAINPAGE +USE_PDFLATEX +VERBATIM_HEADERS +WARNINGS +WARN_AS_ERROR +WARN_FORMAT +WARN_IF_DOC_ERROR +WARN_IF_UNDOCUMENTED +WARN_LOGFILE +WARN_NO_PARAMDOC +XML_OUTPUT +XML_PROGRAMLISTING +) + +set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") + +function(add_doxygen_doc) + set(options) + set(oneValueArgs) + set(multiValueArgs DEPENDS ${DOXYGEN_ARGS}) + + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + file(WRITE ${DOXYGEN_CONFIG_FILE} "# 
Auto-generated doxygen configuration file\n") + + foreach(ARG ${DOXYGEN_ARGS}) + if(PARSE_${ARG}) + string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}}) + file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n") + endif() + endforeach() + + if(PARSE_OUTPUT_DIRECTORY) + if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY}) + file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY}) + endif() + endif() + + if(DOT_EXECUTABLE) + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n") + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n") + else() + file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n") + endif() + + add_custom_target(doxygen + ${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + COMMENT "Building documentation with doxygen" + ) + if(PARSE_OUTPUT_DIRECTORY) + clean_doc_output(${PARSE_OUTPUT_DIRECTORY}) + endif() + mark_as_doc(doxygen) + if(PARSE_DEPENDS) + add_dependencies(doxygen ${PARSE_DEPENDS}) + endif() +endfunction() diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake new file mode 100644 index 0000000000..9f193b2090 --- /dev/null +++ b/cmake/EnableCompilerWarnings.cmake @@ -0,0 +1,110 @@ +################################################################################ +# +# MIT License +# +# Copyright (c) 2017 Advanced Micro Devices, Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################ +# - Enable warning all for gcc/clang or use /W4 for visual studio + +## Strict warning level +if (MSVC) + # Use the highest warning level for visual studio. + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w") + # set(CMAKE_CXX_WARNING_LEVEL 4) + # if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]") + # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") + # else () + # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") + # endif () + + # set(CMAKE_C_WARNING_LEVEL 4) + # if (CMAKE_C_FLAGS MATCHES "/W[0-4]") + # string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") + # else () + # set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4") + # endif () + +else() + foreach(COMPILER C CXX) + set(CMAKE_COMPILER_WARNINGS) + # use -Wall for gcc and clang + list(APPEND CMAKE_COMPILER_WARNINGS + -Wall + -Wextra + -Wcomment + -Wendif-labels + -Wformat + -Winit-self + -Wreturn-type + -Wsequence-point + # Shadow is broken on gcc when using lambdas + # -Wshadow + -Wswitch + -Wtrigraphs + -Wundef + -Wuninitialized + -Wunreachable-code + -Wunused + + -Wno-sign-compare + -Wno-extra-semi-stmt + ) + if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") + list(APPEND CMAKE_COMPILER_WARNINGS + -Weverything + -Wno-c++98-compat + -Wno-c++98-compat-pedantic + -Wno-conversion + -Wno-double-promotion + -Wno-exit-time-destructors + -Wno-extra-semi + -Wno-float-conversion + 
-Wno-gnu-anonymous-struct + -Wno-gnu-zero-variadic-macro-arguments + -Wno-missing-prototypes + -Wno-nested-anon-types + -Wno-padded + -Wno-return-std-move-in-c++11 + -Wno-shorten-64-to-32 + -Wno-sign-conversion + -Wno-unknown-warning-option + -Wno-unused-command-line-argument + -Wno-weak-vtables + -Wno-covered-switch-default + ) + else() + if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") + # cmake 3.5.2 does not support >=. + if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1") + list(APPEND CMAKE_COMPILER_WARNINGS + -Wno-ignored-attributes) + endif() + endif() + list(APPEND CMAKE_COMPILER_WARNINGS + -Wno-missing-field-initializers + -Wno-deprecated-declarations + ) + endif() + add_definitions(${CMAKE_COMPILER_WARNINGS}) + endforeach() +endif () diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp new file mode 100644 index 0000000000..0a1e07ec57 --- /dev/null +++ b/composable_kernel/include/gridwise_operation_wrapper.hpp @@ -0,0 +1,14 @@ +#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER +#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + run_gridwise_operation(Xs... 
xs) +{ + GridwiseOp{}.Run(xs...); +} + +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..09ea16fa23 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,272 @@ +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Number of GEMMs = YTilda * XTilda +// GemmM = C +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilda = Number{}; + constexpr auto IXTilda = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = 
out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = + math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = + math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - 
IXTilda, XTilda); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, 
K0)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + 
make_freeze_transform(IXTilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_pass_through_transform(C), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..9c60e8c3ac --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,275 @@ +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: wei +// C: in +// Number of GEMMs = YTilda * XTilda +// GemmM = N * HTildaSlice * WTildaSlice +// GemmN = C +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& 
in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilda = Number{}; + constexpr auto IXTilda = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + 
const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = + math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = + math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + 
transform_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // B: weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + 
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_freeze_transform(IXTilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..093a46256d --- /dev/null +++ 
b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -0,0 +1,257 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = 
transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + 
+template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, 
Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( + const TensorDescriptor& wei_k_c_y_x_global_desc, + const TensorDescriptor& in_n_c_hi_wi_global_desc, + const TensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = 
wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 
0000000000..9aa27884da --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,176 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc 
= transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple( + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); +} + +template +__host__ 
__device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)), + 
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple( + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..16ae8b470d --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr 
auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 
0000000000..e81c87d046 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + 
+ const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..b0b07505e5 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const 
InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = N * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + 
make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git 
a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..f5cb7f4877 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM0 = 1 +// GemmM1 = K +// GemmN0 = N0 +// GemmN1 = (N / N0) * Ho * Wo +// GemmK0 = (C / C0) * Y * X +// GemmK1 = C0 +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const N0Type& N0, + const C0Type& C0) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto 
ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto N1 = N / N0; + const auto C1 = C / C0; + + // weight tensor + const auto wei_gk0_gm0_gm1_gk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_unmerge_transform(make_tuple(I1, K)), + make_unmerge_transform(make_tuple(C0, C1 * Y * X))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(C0, C1)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); + + const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor( + in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C1, Y, X)), + make_pass_through_transform(N0), + make_merge_transform(make_tuple(N1, Ho, Wo)), + 
make_pass_through_transform(C0)), + make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_n_k_howo_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)); + + const auto out_n0_n1_1_k_howo_grid_desc = + transform_tensor_descriptor(out_n_k_howo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(Ho * Wo)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor( + out_n0_n1_1_k_howo_grid_desc, + make_tuple(make_pass_through_transform(I1), + make_pass_through_transform(K), + make_pass_through_transform(N0), + make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), + make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return make_tuple( + wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/composable_kernel/include/tensor_description/cluster_descriptor.hpp new file mode 100644 index 0000000000..d69bfb70c1 --- /dev/null +++ b/composable_kernel/include/tensor_description/cluster_descriptor.hpp @@ -0,0 +1,33 @@ +#ifndef CK_CLUSTER_DESCRIPTOR_HPP +#define CK_CLUSTER_DESCRIPTOR_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" + +namespace ck { + +template ::type> +__host__ __device__ constexpr auto make_cluster_descriptor( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::Size(); + + const auto 
reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = Sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp new file mode 100644 index 0000000000..42a5a875b7 --- /dev/null +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -0,0 +1,1741 @@ +#ifndef CK_MULTI_INDEX_TRANSFORM_HPP +#define CK_MULTI_INDEX_TRANSFORM_HPP + +#include "common_header.hpp" +#include "multi_index.hpp" + +namespace ck { + +template +struct PassThrough +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr PassThrough() = default; + + __host__ __device__ constexpr PassThrough(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("PassThrough, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Pad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + RightPadLength right_pad_length_; + + __host__ __device__ constexpr Pad() = default; + + __host__ __device__ constexpr Pad(const LowLength& low_length, + const LeftPadLength& left_pad_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)}, + left_pad_length_{left_pad_length}, + right_pad_length_{right_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ 
__device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || + ((idx_up[Number<0>{}] >= left_pad_length_) && + (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_)); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Pad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_length %d", index_t{left_pad_length_}); + printf("right_pad_length %d", index_t{right_pad_length_}); + printf("}"); + } +}; + +template +struct LeftPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{})); + + UpLengths up_lengths_; + LeftPadLength 
left_pad_length_; + + __host__ __device__ constexpr LeftPad() = default; + + __host__ __device__ constexpr LeftPad(const LowLength& low_length, + const LeftPadLength& left_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("LeftPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_length_ %d", index_t{left_pad_length_}); + printf("}"); + } +}; + +template +struct RightPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPadLength right_pad_length_; + + __host__ __device__ constexpr RightPad() = default; + + __host__ __device__ constexpr RightPad(const LowLength& low_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + right_pad_length)}, + low_length_{low_length}, + right_pad_length_{right_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("RightPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("low_length_ %d", index_t{low_length_}); + printf("right_pad_length_ %d", index_t{right_pad_length_}); + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of Number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct Embed +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + UpLengths up_lengths_; + Coefficients coefficients_; + +
__host__ __device__ constexpr Embed() = default; + + __host__ __device__ constexpr Embed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && + LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! 
inconsistent # of dimension"); + + idx_diff_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Embed, "); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("coefficients_ "); + print_multi_index(coefficients_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses regular to do lowering of +// multi-index and use carry-and-borrow check to do lowering of multi-index delta +template +struct Merge_v1_carry_check +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v1_carry_check() = default; + + __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + 
{ + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // normal division + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only calculate idx_diff_low_const once and save it during the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive.
+ LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]); + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 
1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] - borrow; + + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. 
+ // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only calculate idx_diff_low_const once and save it during the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR.
Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t negative_idx_low_tmp = borrow - idx_low[i]; + + bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 
1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only calculate idx_diff_low_const once and save it during the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at run-time each time this function is called, and can be + // very expensive.
+ LowerIndex idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + bool do_carry = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] + do_carry; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_carry = idx_low_tmp >= low_lengths_[i]; + +#if 0 + // TODO: use exec-mask inline asm, which use 1 VALU + if(do_carry) + { + idx_diff_low(i) -= low_lengths_[i]; + } +#elif 1 + // this use 2 VALU + idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i]; +#elif 1 + // this use 2 VALU + index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i]; + idx_diff_low(i) = do_carry ? 
idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry; + + idx_low(I0) += idx_diff_low[I0]; + } + else if constexpr(Hack == 2) + { + // do borrow check on each low dimension in reversed order + // do not need to check the first dimension + bool do_borrow = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] - do_borrow; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_borrow = idx_low_tmp < 0; + +#if 0 + // TODO: use exec-mask inline asm + if(do_borrow) + { + idx_diff_low(i) += low_lengths_[i]; + } +#elif 1 + idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i]; +#elif 1 + index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i]; + idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow; + + idx_low(I0) += idx_diff_low[I0]; + } + else + { + // not implemented + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { +#if 1 + UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#elif 0 + UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#else + UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#endif + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + 
__host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v1_carry_check, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_shift +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicShift(LowLengths{}[i]); + } +}; + +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. 
+template +struct Merge_v2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + using LowLengthsMagicDivisorMultipiler = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_; + LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v2_magic_division() = default; + + __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); }, + Number{})}, + low_lengths_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + }); + + idx_low(Number<0>{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + + index_t idx_low_old = idx_low[i]; + + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + + idx_diff_low(i) = idx_low[i] - idx_low_old; + }); + + idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{}); + + idx_low(Number<0>{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_magic_divisor_multiplier_ "); + 
print_multi_index(low_lengths_magic_divisor_multiplier_); + printf("low_lengths_magic_divisor_shift_ "); + print_multi_index(low_lengths_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. 
+template +struct Merge_v2r2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsScanMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_; + LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v2r2_magic_division() = default; + + __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + low_lengths_scan_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); }, + Number{})}, + low_lengths_scan_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return 
up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + index_t idx_low_old = idx_low[i]; + + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + idx_diff_low(i) = idx_low[i] - idx_low_old; + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_diff_low(Number{}) = tmp - idx_low[Number{}]; + + idx_low(Number{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + 
__host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v2r2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan "); + print_multi_index(low_lengths_scan_); + printf("low_lengths_scan_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_scan_magic_divisor_multiplier_); + printf("low_lengths_scan_magic_divisor_shift_ "); + print_multi_index(low_lengths_scan_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct UnMerge +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + __host__ __device__ constexpr UnMerge() = default; + + __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{ + container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(Number<0>{}) = + (0x00ffffff & idx_low[Number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } 
+ + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + CalculateLowerIndex(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("UnMerge, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("up_lengths_scan_"); + print_multi_index(up_lengths_scan_); + printf("}"); + } +}; + +template +struct Freeze +{ + LowerIndex low_idx_; + + __host__ __device__ constexpr Freeze() = default; + + __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& /* idx_up */) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, + "wrong! 
inconsistent # of dimension"); + + idx_low(Number<0>{}) = low_idx_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */, + Number) + { + idx_diff_low(Number<0>{}) = 0; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Freeze"); + printf("low_idx_ %d", index_t{low_idx_}); + } +}; + +// Insert a dangling upper dimension without lower dimension +template +struct Insert +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr Insert() = default; + + __host__ __device__ constexpr Insert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + __host__ __device__ static void + UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number) + { + static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Insert"); + print_multi_index(up_lengths_); + } +}; + +template +struct Vectorize +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(UpLength{})); + + UpLengths up_lengths_; + VectorSize vector_size_; + + __host__ __device__ constexpr Vectorize() = default; + + __host__ __device__ constexpr Vectorize(const VectorSize& vector_size, + const UpLength& up_length) + : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}]; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = vector_size_ * idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Vectorize, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Slice +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + __host__ __device__ constexpr Slice() = default; + + __host__ __device__ constexpr Slice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Slice, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("slice_begin_ %d", index_t{slice_begin_}); + printf("slice_end %d", index_t{slice_end_}); + printf("}"); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp new file mode 100644 index 0000000000..6d4e01888b --- /dev/null +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -0,0 +1,103 @@ +#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP +#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP + +#include "common_header.hpp" +#include "multi_index_transform.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return 
PassThrough{low_length}; +} + +template +__host__ __device__ constexpr auto +make_pad_transform(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return Pad{low_length, left_pad, right_pad}; +} + +template +__host__ __device__ constexpr auto make_left_pad_transform( + const LowLength& low_length, + const LeftPadLength& left_pad, + integral_constant = integral_constant{}) +{ + return LeftPad{low_length, left_pad}; +} + +template +__host__ __device__ constexpr auto make_right_pad_transform( + const LowLength& low_length, + const RightPadLength& right_pad, + integral_constant = integral_constant{}) +{ + return RightPad{low_length, right_pad}; +} + +template ::type = false> +__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return Embed{up_lengths, coefficients}; +} + +template +__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ +#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return Merge_v1_carry_check{low_lengths}; +#else +#if 1 + return Merge_v2_magic_division{low_lengths}; +#else + return Merge_v2r2_magic_division{low_lengths}; +#endif +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ + return Merge_v2_magic_division{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_unmerge_transform( + const UpLengths& up_lengths, + integral_constant = integral_constant{}) +{ + return UnMerge{up_lengths}; +} + +template +__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return Freeze{low_idx}; +} + +template +__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return Slice{low_length, slice_begin, slice_end}; +} + +template +__host__ __device__ 
constexpr auto make_vectorize_transform(const VectorSize& vector_size, + const UpLength& up_length) +{ + return Vectorize{vector_size, up_length}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp new file mode 100644 index 0000000000..3b647e433a --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -0,0 +1,464 @@ +#ifndef CK_TENSOR_ADAPTOR_HPP +#define CK_TENSOR_ADAPTOR_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct TensorAdaptor +{ + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() + { + return LowerDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss() + { + return UpperDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetTopDimensionHiddenIds() + { + return TopDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds() + { + return BottomDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! 
not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_top = Number{}; + + constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + __host__ __device__ static constexpr index_t GetNumOfBottomDimension() + { + return BottomDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfTopDimension() + { + return TopDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension(); + constexpr static index_t ndim_top_ = GetNumOfTopDimension(); + + using HiddenIndex = MultiIndex; + using BottomIndex = MultiIndex; + using TopIndex = MultiIndex; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: + __host__ __device__ constexpr TensorAdaptor() = default; + + __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionHiddenIdss::Size() == ntransform_ && + UpperDimensionHiddenIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = GetNumOfTransform(); + constexpr index_t ndim_hidden = GetNumOfHiddenDimension(); + + MultiIndex idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = GetTransforms().At(itran); + constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran); + constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= + remove_cv_t>::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorAdaptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionHiddenIds:"); + LowerDimensionHiddenIdss{}.At(i).Print(); + printf("UpperDimensionHiddenIds:"); + UpperDimensionHiddenIdss{}.At(i).Print(); + }); + + printf("BottomDimensionHiddenIds:"); + BottomDimensionHiddenIds::Print(); + printf("TopDimensionHiddenIds:"); + TopDimensionHiddenIds::Print(); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +template +__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::GetNumOfTopDimension() == + TensorAdaptor1::GetNumOfBottomDimension(), + "wrong!"); + + // all_transforms 
= transform0 + transform1 + const auto all_transforms = + container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); + + static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id_; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); + + static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::GetBottomDimensionHiddenIds()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + 
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id_ = + math::min(adaptor1_min_hidden_id_, + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor1_min_hidden_id_; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension(); + + // all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod_; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_low_dim_hidden_idss = + 
container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod_; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::GetTopDimensionHiddenIds() + Number{}; + + // put everything together + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms}; +} + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::Size(); + + static_assert(LowerDimensionOldTopIdss::Size() == ntransform && + 
UpperDimensionNewTopIdss::Size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number{}; }, + Number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number{}; + + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... 
xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp new file mode 100644 index 0000000000..a6a57ba63b --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -0,0 +1,597 @@ +#ifndef CK_TENSOR_DESCRIPTOR_HPP +#define CK_TENSOR_DESCRIPTOR_HPP + +#include "common_header.hpp" +#include "multi_index_transform.hpp" + +namespace ck { + +template +struct TensorCoordinate; + +template +struct TensorCoordinateStep; + +// Transforms: Tuple +// LowerDimensionIdss : Tuple, ...> +// UpperDimensionIdss : Tuple, ...> +// VisibleDimensionIds> : Sequence<...> +template +struct TensorDescriptor +{ + // TODO make these private + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ static constexpr index_t GetNumOfVisibleDimension() + { + return VisibleDimensionIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_visible) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_visible = Number{}; + + constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + + using VisibleIndex = MultiIndex; + using HiddenIndex = MultiIndex; + using Coordinate = 
TensorCoordinate; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: + __host__ __device__ constexpr TensorDescriptor() = default; + + __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : transforms_{transforms}, + element_size_{InitializeElementSize(transforms)}, + element_space_size_{element_space_size} + + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionIdss::Size() == ntransform_ && + UpperDimensionIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ static constexpr index_t GetNumOfDimension() + { + return GetNumOfVisibleDimension(); + } + + template + __host__ __device__ constexpr auto GetLength(Number) const + { + static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range"); + + constexpr auto tmp = GetTransformAndItsUpperDimension(Number{}); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[Number{}].GetUpperLengths()[Number{}]; + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } + + template + __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const + { + static_assert(Idx::Size() == GetNumOfDimension(), "wrong! 
inconsistent # of dimension"); + + return make_tensor_coordinate(*this, idx).GetOffset(); + } + + // TODO make these private + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionIdss() + { + return LowerDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionIdss() + { + return UpperDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetVisibleDimensionIds() + { + return VisibleDimensionIds{}; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= + remove_cv_t>::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorDescriptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionIds:"); + LowerDimensionIdss{}.At(i).Print(); + printf("UpperDimensionIds:"); + UpperDimensionIdss{}.At(i).Print(); + }); + printf("}"); + + VisibleDimensionIds::Print(); + } + + // TODO make these private + Transforms transforms_; + ElementSize element_size_; + ElementSpaceSize element_space_size_; +}; + +template +struct TensorCoordinate +{ + // TODO make these private + static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); + + using HiddenIndex = MultiIndex; + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr TensorCoordinate() = default; + + __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); } + + __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } + + // TODO make these private + 
__host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + + __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + + __host__ __device__ constexpr auto GetVisibleIndex() const + { + return get_container_subset(idx_hidden_, VisibleDimensionIds{}); + } + + // TODO make these private + HiddenIndex idx_hidden_; +}; + +template +struct TensorCoordinateStep +{ + // TODO make these private + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr TensorCoordinateStep() = default; + + __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible, + const MultiIndex& do_transforms) + : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} + { + } + + __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } + + // TODO make these private + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + { + return idx_diff_visible_; + } + + VisibleIndex idx_diff_visible_; + MultiIndex do_transforms_; + + // HACK: control UpdateLowerIndex() + static constexpr UpdateLowerIndexHack update_lower_index_hack_; +}; + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_tensor_descriptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return Number{}; + } +}; + +template +__host__ __device__ constexpr auto +transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + // sanity check + { + constexpr auto all_old_top_ids = unpack([](auto... 
xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + + // lower dimension's hidden idss + // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_visible_ids) constexpr { + return transform_sequences( + // convert lower dimension visible id to hidden id + [](auto low_dim_visible_id) constexpr { + return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; + }, + low_dim_visible_ids); + }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr index_t num_new_transform = NewTransforms::Size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, Number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + + constexpr auto up_dim_hidden_idss = generate_tuple( + [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + Number{}); + + // new visible dimension's hidden ids + constexpr auto unordered_new_visible_dim_hidden_ids = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_visible_dim_unordered2ordered = unpack( + [](auto... 
xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + constexpr auto new_visible_dim_hidden_ids = + unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); + + // put everything together + const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); + + const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc, + const VisibleIndex& idx_visible) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + MultiIndex idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, visible_dim_ids, idx_visible); + + // calculate hidden index + static_for{}([&tensor_desc, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return TensorCoordinate{idx_hidden}; +} + +// UpdateLowerIndexHack: Sequence<...> +// HACK: control UpdateLowerIndex +template +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible, + UpdateLowerIndexHack) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!"); + + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + auto is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + MultiIndex non_zero_diff_pick_visible; + + static_for<0, ndim_visible, 1>{}( + [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); }); + + set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible); + + static_for{}([&](auto itran) { + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + MultiIndex non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.Size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + + return TensorCoordinateStep{idx_diff_visible, + do_transforms}; +} + +template +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible) +{ + constexpr index_t ntransform = 
TensorDesc::GetNumOfTransform(); + + return make_tensor_coordinate_step( + TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc, + TensorCoord& coord, + const TensorCoordStep& coord_step) +{ + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize visible index diff + set_container_subset( + idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff()); + + // this is what needs to be updated + auto& idx_hidden = coord.GetHiddenIndex(); + + // update visible index + auto idx_hidden_pick_visible = + get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); + + idx_hidden_pick_visible += coord_step.GetIndexDiff(); + + set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(coord_step.do_transforms_[itran]) + { + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + MultiIndex idx_diff_low; + + // HACK: control UpdateLowerIndex for Merge using hack + constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran); + + tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); +} + 
+template +__host__ __device__ constexpr bool +coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + const auto& idx_hidden = coord.GetHiddenIndex(); + + static_for{}([&tensor_desc, &idx_hidden, &valid](auto itran) { + const auto tran = tensor_desc.GetTransforms().At(itran); + + // check validity, only if current transformation does not always has a valid mapping + if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex()) + { + const auto idx_up = + get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran)); + + // Comment: using valid = valid && .. will result in weird control flow in ISA + valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up); + } + }); + + return valid; +} + +template +__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + // check visible index + const auto& idx_visible = coord.GetVisibleIndex(); + + bool is_visible_index_valid = true; + + static_for<0, TensorDesc::GetNumOfDimension(), 1>{}( + [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) { + is_visible_index_valid = + is_visible_index_valid && + (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i)); + }); + + // check other hidden index + return is_visible_index_valid && + coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord); +} + +template +using TensorCoordinate_t = decltype(make_tensor_coordinate( + TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + +template +using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( + TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp 
b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp new file mode 100644 index 0000000000..ad75f9245e --- /dev/null +++ b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp @@ -0,0 +1,149 @@ +#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP +#define CK_TENSOR_DESCRIPTOR_HELPER_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "multi_index_transform_helper.hpp" + +namespace ck { + +/* + * These functions create tensor descriptor at runtime. If they are not constexpr, you will + * likely see usage of scratch memory during construction of these tensor descriptors. So + * it's better to call these functions on host and then pass the constructed tensor descritpors + * to GPU. If the tensor descritpors being constructed are constexpr, then you can call these + * functions on GPU without worrying about scratch memory usage. + */ + +#if CK_WORKAROUND_SWDEV_275126 +template +__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + Number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::Size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } +} +#endif + +template ::type = false> +__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple& lengths, + const Tuple& strides) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + +#if !CK_WORKAROUND_SWDEV_275126 + // rocm-4.1 compiler would crash for recursive labmda + // 
recursive function for reduction + auto f = [&](auto fs, auto i, auto acc_old) { + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < N - 1) + { + return fs(fs, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } + }; + + const auto element_space_size = f(f, Number<0>{}, Number<1>{}); +#else + const auto element_space_size = + calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); +#endif + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... can be: +// 1) index_t, which is known at run-time +// 2) Number<>, which is known at compile-time +template +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_packed(const Tuple& lengths) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{}); + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align align) +{ + constexpr auto I1 = Number<1>{}; + + constexpr index_t N = sizeof...(Lengths); + + const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number{}], align); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return I1; + } + else if constexpr(i.value == N - 2) + { + return Number{}; + } + else + { + return container_reduce(lengths, + math::multiplies{}, + Number{}, + i + I1, + 
Number{}, + I1); + } + }, + Number{}); + + return make_naive_tensor_descriptor(lengths, strides); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp new file mode 100644 index 0000000000..35ff66a2b0 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r2.hpp @@ -0,0 +1,394 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_contraction_dlops.hpp" + +namespace ck { + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. AKMBlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BKNBlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// M0 = N0 = 2. 
It will do 2x2 pipelined read and fma (ABBA optimization) +template < + index_t BlockSize, + typename FloatA, + typename FloatB, + typename FloatC, + typename AKMBlockDesc, + typename BKNBlockDesc, + index_t M1PerThreadM11, + index_t N1PerThreadN11, + index_t KPerThread, + index_t M1N1ThreadClusterM100, + index_t M1N1ThreadClusterN100, + index_t M1N1ThreadClusterM101, + index_t M1N1ThreadClusterN101, + index_t AThreadCopyScalarPerVector_M11, + index_t BThreadCopyScalarPerVector_N11, + typename enable_if::type = false> +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t K = AKMBlockDesc{}.GetLength(I0); + static constexpr index_t M = AKMBlockDesc{}.GetLength(I1); + static constexpr index_t N = BKNBlockDesc{}.GetLength(I1); + + static constexpr index_t M100 = M1N1ThreadClusterM100; + static constexpr index_t N100 = M1N1ThreadClusterN100; + + static constexpr index_t M101 = M1N1ThreadClusterM101; + static constexpr index_t N101 = M1N1ThreadClusterN101; + + static constexpr index_t M11 = M1PerThreadM11; + static constexpr index_t N11 = N1PerThreadN11; + + static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; + static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + + static constexpr index_t M0 = M / M1; + static constexpr index_t N0 = N / N1; + + __host__ __device__ static constexpr auto + MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */) + { + const auto a_k_m0_m1_block_desc = transform_tensor_descriptor( + AKMBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_block_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */) + { + const auto b_k_n0_n1_block_desc = transform_tensor_descriptor( + BKNBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_block_desc; + } + + __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M, N] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + } + + __host__ __device__ static constexpr auto + MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M0, M1, N0, N1] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + } + + __host__ __device__ 
static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + { + return Sequence{}; + } + + static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{}); + static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); + + public: + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} + { + static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == M101 * M100 * N101 * N100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + + static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + { + // lower: [M0, M1, N0, N1] + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + + // lower: [M0, M100, M101, M11, N0, N100, N101, N11] + // upper: [Tid, M0, M11, N0, N11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), + make_pass_through_transform(M0), + make_pass_through_transform(M11), + make_pass_through_transform(N0), + make_pass_through_transform(N11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } + + __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; } + + template + __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && + CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_k_m0_m1_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_k_n0_n1_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + 
make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[K, M0, M1] + static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + 
AThreadCopyScalarPerVector_M11, + 1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N11, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp new file mode 100644 index 0000000000..26ca0bf111 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v2r3.hpp @@ -0,0 +1,410 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_contraction_dlops.hpp" + +namespace ck { + +// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 +// BM0 = BN0 = 2. 
It will do 2x2 pipelined read and fma (ABBA optimization) +template + typename BM10BN10ThreadClusterBN10Xs, // Sequence + index_t AThreadCopyScalarPerVector_BM11, + index_t BThreadCopyScalarPerVector_BN11, + typename enable_if::type = false> +struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); + static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); + static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); + static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); + + static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; + + static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; + + static constexpr index_t BM11 = BM1PerThreadBM11; + static constexpr index_t BN11 = BN1PerThreadBN11; + + static constexpr index_t BM1 = BM100 * BM101 * BM11; + static constexpr index_t BN1 = BN100 * BN101 * BN11; + + static constexpr index_t BM0 = BM / BM1; + static constexpr index_t BN0 = BN / BN1; + + __host__ __device__ static constexpr auto + MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) + { + const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( + a_block_desc_bk0_bm_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + 
return a_block_bk0_bm0_bm1_bk1; + } + + __host__ __device__ static constexpr auto + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) + { + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( + b_block_desc_bk0_bn_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_block_desc_bk0_bn0_bn1_bk1; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM, BN] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM0, BM1, BN0, BN1] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, 
Sequence<5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; + } + + __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() + { + return Sequence{}; + } + + static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = + MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); + + static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); + + public: + __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)} + { + static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && + BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == BM101 * BM100 * BN101 * BN100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); + + static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == + BBlockDesc_BK0_BN_BK1{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO remove this restriction + static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && + BM10BN10ThreadClusterBN10Xs::Size() == 2, + "wrong!"); + + // TODO: remove this restriction + static_assert(BM0 == 2 && BN0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) + { + // lower: [BM0, BM1, BN0, BN1] + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + constexpr auto adaptor0 = + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); + + // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // upper: [Tid, BM0, BM11, BN0, BN11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), + make_pass_through_transform(BM0), + make_pass_through_transform(BM11), + make_pass_through_transform(BN0), + make_pass_through_transform(BN11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + template + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(BM0 == 2 && BN0 == 2 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); + + constexpr auto threadwise_contraction = + ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + FloatA, + FloatB, + FloatC, + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + CThreadDesc_BM0_BM11_BN0_BN11, + Sequence, + Sequence<1, BM1PerThreadBM11>, + Sequence<1, BN1PerThreadBN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + 
make_tuple(I0, I0, I1, I0)); + + // loop over rest of bk0 + static_for{}([&](auto bk0) { + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * 
B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[BK0, BM0, BM1, BK1] + static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + // B[BK0, BN0, BN1, BK1] + static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatA, + FloatA, + decltype(a_block_desc_bk0_bm0_bm1_bk1_), + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatB, + FloatB, + decltype(b_block_desc_bk0_bn0_bn1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000..03f889649e --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -0,0 +1,185 @@ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 +{ + struct MatrixIndex + { + index_t k; 
+ index_t h; + index_t w; + }; + + // HACK: fix this @Jing Zhang + static constexpr index_t KPerThreadSubC = 4; + + static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1>, + 1, + ThreadGemmADataPerRead_K, + 1>; + + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() + : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)} + { + static_assert(BlockMatrixA::IsKnownAtCompileTime() && + BlockMatrixB::IsKnownAtCompileTime() && + ThreadMatrixC::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), + "wrong! K dimension not consistent\n"); + + constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed + constexpr index_t H = BlockMatrixB{}.GetLength(I2); + constexpr index_t W = BlockMatrixB{}.GetLength(I3); + + static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0, + "wrong! Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = K / KPerThread; + constexpr auto HThreadCluster = H / HPerThread; + constexpr auto WThreadCluster = W / WPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! 
wrong blocksize\n"); + } + + __device__ static constexpr auto GetThreadMatrixCLengths() + { + return Sequence{}; + } + + __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) + { + constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); + constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); + + constexpr auto num_w_threads = W / WPerThread; + constexpr auto num_h_threads = H / HPerThread; + constexpr auto num_hw_threads = num_w_threads * num_h_threads; + + index_t k_thread_id = thread_id / num_hw_threads; + index_t hw_thread_id = thread_id % num_hw_threads; + + index_t h_thread_id = hw_thread_id / num_w_threads; + index_t w_thread_id = hw_thread_id % num_w_threads; + + return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + + constexpr auto EPerBlock = a_block_mtx.GetLength(I0); + + // HACK: fix this @Jing Zhang + constexpr auto HoPerThreadSubC = 2; + constexpr auto WoPerThreadSubC = 2; + + static_assert(KPerThread % KPerThreadSubC == 0, ""); + static_assert(HPerThread % HoPerThreadSubC == 0, ""); + static_assert(WPerThread % WoPerThreadSubC == 0, ""); + + // thread A buffer for GEMM + StaticBuffer + a_thread_buf; + + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; + + static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) { + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0), + a_thread_buf); + + static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) { + static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) { + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, h_begin, w_begin), + c_thread_buf, + make_tuple(k_begin, I0, h_begin, w_begin)); + }); + }); + }); + }); + } + + template + __device__ void MoveASliceWindow(const BlockMatrixA&, + const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) + { + a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx); + } + + private: + MatrixIndex c_thread_begin_mtx_idx_; + + AThreadCopy a_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp new file mode 100644 index 0000000000..ee6a0b7427 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -0,0 +1,524 @@ +#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP +#define CK_BLOCKWISE_GEMM_XDLOPS_HPP + +#include "common_header.hpp" +#include "threadwise_tensor_slice_transfer.hpp" 
+#include "xdlops_gemm.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +{ + + using CIndex = MultiIndex<2>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + + static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t MWaves = M1 / MPerWave; + static constexpr index_t NWaves = N1 / NPerWave; + + static constexpr index_t MRepeat = M0; + static constexpr index_t NRepeat = N0; + + __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + + __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + + __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, m_offset, 0); + } + else + { + const index_t m_offset = waveId_m * MPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, m_offset, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t 
waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, n_offset, 0); + } + else + { + const index_t n_offset = waveId_n * NPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, n_offset, 0); + } + } + + template + __device__ static CIndex + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + + const index_t waveId = get_thread_local_1d_id() / WaveSize; + + const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; + const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + + return CIndex{m_offset, n_offset}; + } + + __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() + : a_thread_copy_{CalculateAThreadOriginDataIndex()}, + b_thread_copy_{CalculateBThreadOriginDataIndex()} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + "wrong! 
K1 dimension not consistent"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); + + static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + vector_type a_thread_vec; + + vector_type b_thread_vec; + + static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { + // read A + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + using mfma_input_type = + typename vector_type::type; + + static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { + a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; + }); + + static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { + b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + xdlops_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf); + }); + }); + }); + } + + private: + // A[K, M] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + + // B[K, N] + static 
constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + K1, + 1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + K1, + 1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +template +struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline +{ + + using CIndex = MultiIndex<2>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + + static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + static constexpr index_t MWaves = M1 / MPerWave; + static constexpr index_t NWaves = N1 / NPerWave; + + static constexpr index_t MRepeat = M0; + static constexpr index_t NRepeat = N0; + + __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + + __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + + __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = 
xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, m_offset, 0); + } + else + { + const index_t m_offset = waveId_m * MPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, m_offset, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, n_offset, 0); + } + else + { + const index_t n_offset = waveId_n * NPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, n_offset, 0); + } + } + + template + __device__ static CIndex + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + + const index_t waveId = get_thread_local_1d_id() / WaveSize; + + const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; + const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + + return CIndex{m_offset, n_offset}; + } + + __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() + : a_thread_copy_{CalculateAThreadOriginDataIndex()}, + b_thread_copy_{CalculateBThreadOriginDataIndex()} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + "wrong! 
K1 dimension not consistent"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); + + static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, 
c_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I1, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I1, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + } + + private: + // A[K, M] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + + // B[K, N] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + 1, // K1, + 1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + 1, // K1, + 1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp new file mode 
100644 index 0000000000..0214b71352 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer.hpp @@ -0,0 +1,170 @@ +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! 
BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowStepHack& 
src_move_slice_window_step_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_step_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp new file mode 100644 index 0000000000..6b2d2d5231 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_tensor_slice_transfer_v2.hpp @@ -0,0 +1,156 @@ +#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. 
ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! 
BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_step_hack); + } + } + + __device__ void 
MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp new file mode 100644 index 0000000000..fe56d0d813 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_contraction_dlops_v1r2.hpp @@ -0,0 +1,659 @@ +#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP +#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" +#include "blockwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10) +{ + constexpr index_t shared_block_size = + 
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // GM0 and GN0 need to known at compile-time + static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0); + static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2); + static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
GM0 and GN0 need to be known at compile-time"); + + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return ( + (GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) && + GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) && + GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) && + GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) && + GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) && + GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) && + GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) && + GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) && + GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) && + GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) && + (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0)); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr index_t GM11 = GM1PerBlockGM11; + constexpr index_t GN11 = GN1PerBlockGN11; + + const index_t GM10 = GM1 / GM11; + const index_t GN10 = GN1 / GN11; + + const index_t grid_size = GM10 * GN10; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) + { + const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) + { + const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + const 
AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1) + { + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + + const auto GM11 = Number{}; + const auto GM10 = GM1 / GM11; + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor( + a_grid_desc_gk0_gm0_gm1_gk1, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_grid_desc_gk0_gm0_gm10_gm11_gk1; + } + + __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1) + { + const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + + const auto GN11 = Number{}; + const auto GN10 = GN1 / GN11; + + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor( + b_grid_desc_gk0_gn0_gn1_gk1, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return b_grid_desc_gk0_gn0_gn10_gn11_gk1; + } + + __host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + constexpr auto 
BM = GM0 * GM11; + constexpr auto BN = GN0 * GN11; + + constexpr auto BM1 = + Number{}; + constexpr auto BN1 = + Number{}; + + constexpr auto BM0 = BM / BM1; + constexpr auto BN0 = BN / BN1; + + const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor( + c_grid_desc_gm0_gm1_gn0_gn1, + make_tuple(make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor( + c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_merge_transform(make_tuple(GM0, GM11)), + make_pass_through_transform(GN10), + make_merge_transform(make_tuple(GN0, GN11))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor( + c_gm10_bm_gn10_bn_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_unmerge_transform(make_tuple(BM0, BM1)), + make_pass_through_transform(GN10), + make_unmerge_transform(make_tuple(BN0, BN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; + } + + __host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) + { + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const 
auto GN10 = GN1 / GN11; + + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(GM10, GN10))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_grid_block_cluster_blockid_to_gm10_gn10; + } + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{})); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{})); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{})); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); + + const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); + + // divide block work by [GM10, GN10] + const auto c_gm10_gn10_block_cluster_idx = + c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex( + 
make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); + const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); + + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // A matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); + + // B matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); + + static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == + a_block_desc_gk0_bm_gk1.GetElementSpaceSize() && + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() == + b_block_desc_gk0_bn_gk1.GetElementSpaceSize(), + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + 
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1), + decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + make_multi_index(0, 0, igm10, 0, 0), + a_block_desc_gk0_gm0_gm10_gm11_gk1, + make_multi_index(0, 0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1), + decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + make_multi_index(0, 0, ign10, 0, 0), + b_block_desc_gk0_gn0_gn10_gn11_gk1, + make_multi_index(0, 0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS + // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS + // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + 
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_block_desc_gk0_bm_gk1), + decltype(b_block_desc_gk0_bn_gk1), + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + BM1PerThreadBM11, + BN1PerThreadBN11>{}; + + constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); + + ThreadwiseTensorSliceSet_v1{} + .Run(c_thread_desc_bm0_bm1_bn0_bn1, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + 
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t gk0_block_on_grid = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + 
b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); + + gk0_block_on_grid += 2 * GK0PerBlock; + } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, + b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead( + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // 
if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), + decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1), + Sequence<1, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1], + 1, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_multi_index(igm10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1], + ign10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])} + .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_buf, + CGridStepHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000..d91159b884 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r2.hpp @@ -0,0 +1,662 @@ +#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP +#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP + +#include 
"common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r2.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AKM0M1GridDesc a_k_m0_m1_grid_desc, + const BKN0N1GridDesc b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +// pass tensor descriptor by CONSTANT void pointer +// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to +// non-modifiable parameter address space, so compiler can enable corresponding optimization +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* p_b_k_n0_n1_grid_desc, + const void 
CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + // first cast void CONSTANT void* to void* + // second cast void* to Desc* + // the copy constructor of tensor descriptor doesn't take address_space(4) + const auto a_k_m0_m1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc)); + const auto b_k_n0_n1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc)); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#endif + +template +struct GridwiseGemmDlops_km_kn_mn_v1r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = 
make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K == b_k_n_grid_desc.GetLength(I0)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K) + { + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc) + { + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const 
auto a_k_m0_m1_grid_desc = transform_tensor_descriptor( + a_k_m_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc) + { + const auto K = b_k_n_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor( + b_k_n_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + 
const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); + using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AKM0M1GridDesc& a_k_m0_m1_grid_desc, + const BKN0N1GridDesc& b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + const auto K = a_k_m0_m1_grid_desc.GetLength(I0); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // lds max alignment + constexpr auto max_lds_align = 
math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K_M0_M1, + ABlockTransferThreadClusterLengths_K_M0_M1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k_m0_m1_grid_desc), + decltype(a_k_m0_m1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_M1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k_m0_m1_grid_desc, + make_multi_index(0, im0, 0), + a_k_m0_m1_block_desc, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K_N0_N1, + BBlockTransferThreadClusterLengths_K_N0_N1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k_n0_n1_grid_desc), + decltype(b_k_n0_n1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + 
BBlockTransferDstScalarPerVector_N1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k_n0_n1_grid_desc, + make_multi_index(0, in0, 0), + b_k_n0_n1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + // b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{}; + constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{}; + + // hack to control index calculation when move slice 
window for A and B matrix for + // threadwise copy + constexpr auto a_k_m0_m1_global_move_slice_window_step_hack = + AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = + BGridMoveSliceWindowStepHacks{}; + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k_n0_n1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + 
a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_step_hack); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, 
b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridStepHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git 
a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp new file mode 100644 index 0000000000..2653dd4340 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v1r3.hpp @@ -0,0 +1,650 @@ +#ifndef CK_GRIDWISE_GEMM_V1R3_HPP +#define CK_GRIDWISE_GEMM_V1R3_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_dlops_v2r3.hpp" +#include "blockwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v1r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, + const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +// pass tensor descriptor by CONSTANT void pointer +// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to +// non-modifiable parameter address space, so 
compiler can enable corresponding optimization +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc, + const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + // first cast void CONSTANT void* to void* + // second cast void* to Desc* + // the copy constructor of tensor descriptor doesn't take address_space(4) + const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc)); + const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc)); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc)); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor)); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#endif + +template +struct GridwiseGemmDlops_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = 
AK0MK1GridDesc{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && + K1 == a_k0_m_k1_grid_desc.GetLength(I2) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return 
grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) + { + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_k0_m0_m1_k1_grid_desc = + transform_tensor_descriptor(a_k0_m_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_k0_m0_m1_k1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) + { + const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k0_n0_n1_k1_grid_desc = + transform_tensor_descriptor(b_k0_n_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_k0_n0_n1_k1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr 
auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); + using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, + const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const 
CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + 
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m0_m1_k1_grid_desc), + decltype(a_k0_m0_m1_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_k0_m0_m1_k1_grid_desc, + make_multi_index(0, im0, 0, 0), + a_k0_m0_m1_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperationEnum_t::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n0_n1_k1_grid_desc), + decltype(b_k0_n0_n1_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_k0_n0_n1_k1_grid_desc, + make_multi_index(0, in0, 0, 0), + b_k0_n0_n1_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + 
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = 
make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + 
BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K0 - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow( + a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); + b_blockwise_copy.MoveSrcSliceWindow( + b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + 
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridStepHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp new file mode 100644 index 0000000000..84ee6f40ec --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp @@ -0,0 +1,458 @@ +#ifndef CK_GRIDWISE_GEMM_V2_HPP +#define CK_GRIDWISE_GEMM_V2_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include 
"blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "blockwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct GridwiseGemmDlops_km_kn_mn_v3 +{ + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto E = EPerBlock * 3 * 3; + + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + integral_constant, + integral_constant) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_e_k_global_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); + + constexpr auto E = EPerBlock * 3 * 3; + + // const auto E = a_e_k_global_desc.GetLength(I0); + const auto K = a_e_k_global_desc.GetLength(I1); + + const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); + const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); + const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); + +// divide block 
work by [M, N] +#if 0 + const auto ho_block_work_num = Ho / Number{}; + const auto wo_block_work_num = Wo / Number{}; + const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num; + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num; + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#else + // Hack: this force result into SGPR + const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); + const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); + const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num); + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = + __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num); + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#endif + + // lds max alignment + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + // c_thread_mtx definition: this is a mess + // TODO:: more elegent way of defining c_thread_mtx + constexpr auto 
c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto blockwise_gemm = + BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; + + auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const auto k_thread_id = c_thread_mtx_index.k; + const auto ho_thread_id = c_thread_mtx_index.h; + const auto wo_thread_id = c_thread_mtx_index.w; + + const index_t k_block_data_on_global = k_block_work_id * KPerBlock; + const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock; + const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock; + + const index_t ho_thread_data_on_global = + ho_block_data_on_global + ho_thread_id * HoPerThread; + const index_t wo_thread_data_on_global = + wo_block_data_on_global + wo_thread_id * WoPerThread; + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e_k_global_desc), + decltype(a_e_k_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1>, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_e_k_global_desc, + make_multi_index(0, k_block_data_on_global), + a_e_k_desc, + make_multi_index(0, 0)); + + constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto b_threadwise_transfer = + ThreadwiseTensorSliceTransfer_v2, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + 1, + true>( + b_e_n_ho_wo_global_desc, + make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); + + auto a_block_buf = make_dynamic_buffer( + p_shared_block, 
a_e_k_desc.GetElementSpaceSize()); + + // register allocation for output + StaticBuffer + c_thread_buf; + + // initialize output thread tensor + ThreadwiseTensorSliceSet_v1>{} + .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); + + constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; + constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{}; + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = + BGlobalMoveSliceWindowStepHacks{}; + + // double regsiter buffer for b + StaticBuffer + b_thread_even_buf, b_thread_odd_buf; + + // LDS double buffer: preload data + { + a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_even_buf, + b_e_n_ho_wo_global_step_hacks); + + a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); + } + + __syncthreads(); + + if constexpr(HasMainKBlockLoop) + { + index_t e_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_odd_buf, + b_e_n_ho_wo_global_step_hacks); + + // LDS double buffer: GEMM on current data + // TODO: @Zhang Jing: blockwise gemm should be able to move slice window + 
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_even_buf, + b_e_n_ho_wo_global_step_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + e_block_data_begin += 2 * EPerBlock; + + } while(e_block_data_begin < E - 2 * EPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + b_global_buf, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + b_thread_odd_buf, + b_e_n_ho_wo_global_step_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + + blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); + } + + // output: register to global memory + { + // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{}; + + const index_t k_thread_data_on_global = + k_block_data_on_global + k_thread_id * KPerThread; + + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + 
CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>( + c_k_n_ho_wo_global_desc, + make_multi_index( + k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) + .Run(c_k_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + c_k_n_ho_wo_global_desc, + c_global_buf, + c_k_n_ho_wo_global_tensor_step_hacks); + } + } + + // pass tensor descriptor by reference + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + p_shared_block, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by their pointers + template + __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *p_a_e_k_global_desc; + const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc; + const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by void* + template + __device__ void Run(const void* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + 
const void* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const void* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *reinterpret_cast(p_a_e_k_global_desc); + const auto b_e_n_ho_wo_global_desc = + *reinterpret_cast(p_b_e_n_ho_wo_global_desc); + const auto c_k_n_ho_wo_global_desc = + *reinterpret_cast(p_c_k_n_ho_wo_global_desc); + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..207f73072f --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -0,0 +1,799 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0MK1GridDesc a_k0_m_k1_grid_desc, + const BK0NK1GridDesc b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + 
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + const auto a_k0_m_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); + const auto b_k0_n_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); + const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); + const auto c_block_cluster_adaptor = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} +#endif + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static 
constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + // TODO: turn on this + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && + K1 == a_k0_m_k1_grid_desc.GetLength(I2) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) && + (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr auto + MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto xdlops_gemm = XdlopsGemm{}; + + constexpr auto CLayout = xdlops_gemm.GetCLayout(); + + constexpr auto M0 = Number{}; + constexpr auto M1 = Number{}; + constexpr auto M2 = Number{}; + + constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); + constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + + constexpr auto N1 = Number{}; + + const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), + make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); + + return c_m0_m1_m2_n_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); 
+ + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + +#if 1 + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); +#elif 1 + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))), + make_tuple(Sequence<1, 0>{}), + make_tuple(Sequence<0>{})); +#endif + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto 
max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m_k1_grid_desc), + decltype(a_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k0_m_k1_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_k0_m_k1_block_desc, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n_k1_grid_desc), + decltype(b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k0_n_k1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among 
threads, and saved in + // register + // sanity check + + static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && + NPerBlock % (NPerWave * NRepeat) == 0, + "wrong!"); + + constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor( + a_k0_m_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor( + b_k0_n_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto blockwise_gemm = + BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; + + constexpr auto CLayout = blockwise_gemm.GetCLayout(); + + constexpr index_t BlkSize = CLayout.GetBlkSize(); + constexpr index_t NumBlks = CLayout.GetNumBlks(); + constexpr index_t NumXdlops = CLayout.GetNumXdlops(); + + static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); + + constexpr auto c_mr_nr_blk_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + + StaticBuffer, + c_mr_nr_blk_desc.GetElementSpaceSize(), + true> + c_thread_buf; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for 
threadwise copy + constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; + constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); + + a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + +#if 0 + // output: register to global memory + { + constexpr index_t M0 = 
CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr index_t N0 = CLayout.N1(); + constexpr index_t N1 = CLayout.N0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number<1>{}, + Number<1>{}, + Number{}, + Number<1>{}, + Number{}, + Number<1>{})); + + StaticBuffer + c_blk_buf_; + + static_for<0, MRepeat, 1>{}([&](auto mr_i) { + static_for<0, NRepeat, 1>{}([&](auto nr_i) { + constexpr auto blk_off = + c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); + + static_for<0, BlkSize, 1>{}([&](auto j) { + c_blk_buf_(Number{}) = + c_thread_buf[Number{}] + .template AsType()[Number{}]; + }); + }); + }); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + + constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); + constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(c_m0_m1_m2_n_thread_desc), + decltype(c_m0_m1_m2_n_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_m0_m1_m2_n_grid_desc, + make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), + n_thread_data_on_grid / (N1 * NWaves), + m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), + n_thread_data_on_grid % (N1 * NWaves) / N1, + m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), + m_thread_data_on_grid % (M2 * M1) / M2, + 
m_thread_data_on_grid % M2, + n_thread_data_on_grid % N1)} + .Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_blk_buf_, + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_step_hacks); + } +#else + { + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_m0_m1_m2_n_grid_desc, + make_multi_index(0, + 0, + 0, + 0, + m_thread_data_on_grid / (M2 * M1), + m_thread_data_on_grid % (M2 * M1) / M2, + m_thread_data_on_grid % M2, + n_thread_data_on_grid)}; + + auto init_copy = [&](auto c_thread_idx_) { + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_step_hacks); + + return c_thread_idx_; + }; + + auto mrepeat_plus_copy = [&](auto c_thread_idx_) { + constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, 
mrepeat_step_plus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_step_hacks); + }; + + auto nrepeat_plus_copy = [&](auto c_thread_idx_) { + constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_step_hacks); + }; + + auto mrepeat_minus_copy = [&](auto c_thread_idx_) { + constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_step_hacks); + }; + + auto nrepeat_minus_copy = [&](auto c_thread_idx_) { + constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); + c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); + + constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); + c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf[Number{}].template AsType(), + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_step_hacks); + }; + + static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or + (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && 
NRepeat == 2) or + (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or + (MRepeat == 1 && NRepeat == 1), + "wrong"); + + if constexpr(MRepeat == 4 && NRepeat == 4) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + nrepeat_plus_copy(make_tuple(I0, I3)); + mrepeat_plus_copy(make_tuple(I1, I3)); + nrepeat_minus_copy(make_tuple(I1, I2)); + nrepeat_minus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + nrepeat_plus_copy(make_tuple(I2, I1)); + nrepeat_plus_copy(make_tuple(I2, I2)); + nrepeat_plus_copy(make_tuple(I2, I3)); + mrepeat_plus_copy(make_tuple(I3, I3)); + nrepeat_minus_copy(make_tuple(I3, I2)); + nrepeat_minus_copy(make_tuple(I3, I1)); + nrepeat_minus_copy(make_tuple(I3, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + mrepeat_plus_copy(make_tuple(I3, I0)); + nrepeat_plus_copy(make_tuple(I3, I1)); + mrepeat_minus_copy(make_tuple(I2, I1)); + mrepeat_minus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + mrepeat_plus_copy(make_tuple(I1, I2)); + mrepeat_plus_copy(make_tuple(I2, I2)); + mrepeat_plus_copy(make_tuple(I3, I2)); + nrepeat_plus_copy(make_tuple(I3, I3)); + mrepeat_minus_copy(make_tuple(I2, I3)); + mrepeat_minus_copy(make_tuple(I1, I3)); + mrepeat_minus_copy(make_tuple(I0, I3)); + } + } + else if constexpr(MRepeat == 4 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + mrepeat_plus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + nrepeat_plus_copy(make_tuple(I2, I1)); + mrepeat_plus_copy(make_tuple(I3, I1)); + nrepeat_minus_copy(make_tuple(I3, I0)); + } + else + { + 
mrepeat_plus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + mrepeat_plus_copy(make_tuple(I3, I0)); + nrepeat_plus_copy(make_tuple(I3, I1)); + mrepeat_minus_copy(make_tuple(I2, I1)); + mrepeat_minus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 4) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + nrepeat_plus_copy(make_tuple(I0, I3)); + mrepeat_plus_copy(make_tuple(I1, I3)); + nrepeat_minus_copy(make_tuple(I1, I2)); + nrepeat_minus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + nrepeat_plus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + mrepeat_plus_copy(make_tuple(I1, I2)); + nrepeat_plus_copy(make_tuple(I1, I3)); + mrepeat_minus_copy(make_tuple(I0, I3)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + mrepeat_plus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + nrepeat_plus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 1) + { + init_copy(make_tuple(I0, I0)); + mrepeat_plus_copy(make_tuple(I1, I0)); + } + else if constexpr(MRepeat == 1 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + nrepeat_plus_copy(make_tuple(I0, I1)); + } + else if constexpr(MRepeat == 1 && NRepeat == 1) + { + init_copy(make_tuple(I0, I0)); + } + } +#endif + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp 
b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp new file mode 100644 index 0000000000..a925a5cd68 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -0,0 +1,229 @@ +#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP +#define CK_THREADWISE_CONTRACTION_DLOPS_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 +{ + __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto TK = TKLengths{}[I0]; + constexpr auto TM0 = TMLengths{}[I0]; + constexpr auto TM1 = TMLengths{}[I1]; + constexpr auto TN0 = TNLengths{}[I0]; + constexpr auto TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK, 1>{}([&](auto tk) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +{ + __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr index_t TK0 = TKLengths{}[I0]; + constexpr index_t TK1 = TKLengths{}[I1]; + constexpr index_t TM0 = TMLengths{}[I0]; + constexpr index_t TM1 = TMLengths{}[I1]; + constexpr index_t TN0 = TNLengths{}[I0]; + constexpr index_t TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK0, 1>{}([&](auto tk0) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + vector_type a_vec; + vector_type b_vec; + + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); + + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); + + a_vec.template 
AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; + }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product( + a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp new file mode 100644 index 0000000000..015ad675fb --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -0,0 +1,160 @@ +#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP +#define CK_THREADWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +// Assume: +// 1. ADesc, BDesc, CDesc are known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km_kn_mn_v3 +{ + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! 
inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto E = ADesc{}.GetLength(I0); + constexpr auto K = ADesc{}.GetLength(I1); + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, E, 1>{}([&](auto e) { + static_for<0, K, 1>{}([&](auto k) { + constexpr index_t a_offset = + ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k)); + + if constexpr(H == 2 && W == 2) + { + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1)); + + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1)); + constexpr index_t c_offset_2 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + } + else if constexpr(H == 4 && W == 1) + { + constexpr index_t b_offset_0 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0)); + constexpr index_t b_offset_1 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0)); + constexpr index_t b_offset_2 = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0)); + constexpr index_t b_offset_3 = + BDesc{}.CalculateOffset(b_origin_idx + 
make_tuple(e, 0, 3, 0)); + + constexpr index_t c_offset_0 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0)); + constexpr index_t c_offset_1 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0)); + constexpr index_t c_offset_2 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0)); + constexpr index_t c_offset_3 = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + } + else + { + static_for<0, H, 1>{}([&](auto h) { + static_for<0, W, 1>{}([&](auto w) { + constexpr index_t b_offset = + BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w)); + + constexpr index_t c_offset = + CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w)); + +#if 0 + c_buf(Number{}) += inner_product_with_conversion{}( + a_buf[Number{}], b_buf[Number{}]); +#else + amd_assembly_inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); +#endif + }); + }); + } + }); + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp new file mode 100644 index 0000000000..0c7aa978a7 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp @@ -0,0 +1,59 @@ +#ifndef CK_THREADWISE_TENSOR_SET_HPP +#define CK_THREADWISE_TENSOR_SET_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. Desc is known at compile-time +// 2. Buffer is StaticBuffer +// 3. OriginIdx is known at compile-time +// 4. 
use #-step +template ::type = false> +struct ThreadwiseTensorSliceSet_v1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + template + __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const + { + static_assert(Desc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>>::value, + "wrong! OriginIdx need to be known at compile-time"); + + // Desc is known at compile-time + constexpr auto desc = remove_cv_t>{}; + + // OriginIdx is known at compile-time + constexpr auto origin_idx = to_multi_index(OriginIdx{}); + + static_ford{}([&](auto access_idx) { + constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx); + + constexpr bool is_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); + + constexpr index_t offset = coord.GetOffset(); + + if constexpr(is_valid) + { + buf(Number{}) = initial_value; + } + }); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp new file mode 100644 index 0000000000..0071accf7f --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -0,0 +1,1437 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP +#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. 
Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; +} // namespace detail + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. DstSliceOrginIdx is not known at compile time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v1r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, + const Index& dst_slice_origin_idx) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! 
SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + const DstStepHacks& dst_step_hacks) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // static_assert(is_same>, + // remove_cv_t>>::value, + //"wrong! SrcBuffer data type is wrong"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? 
dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + }(); + + typename vector_type_maker::type dst_vector; + + using dst_vector_t = + typename vector_type_maker::type::type; + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); + + dst_vector.template AsType()(i) = + type_convert{}(src_buf[Number{}]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const 
DstDesc& dst_desc, + DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in Run(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + DstCoord dst_coord_; +}; // namespace ck + +// Assume: +// 1. src: +// 1. SrcDesc is not known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_slice_origin_idx is not known at compile-time +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. dst_slice_origin_idx is known at compile-time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, + const Index& src_slice_origin_idx) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! 
SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf, + const SrcStepHacks& src_step_hacks) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? 
src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; + }(); + + typename vector_type_maker::type src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = src_vector.template AsType()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const 
DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in Run(), if it has not being reset by + // RunWrite() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; +}; // namespace ck + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. 
Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! current implementation assume SrcData and DstData are same type"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_tmp_vector + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_tmp_vector to buffer_ + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); + + buffer_(Number{}) = src_tmp_vector.template AsType()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + 
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? 
dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + vector_type_maker_t dst_tmp_vector; + + // copy data from buffer_ to dst_tmp_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); + + dst_tmp_vector.template AsType()(i) = buffer_[Number{}]; + }); + + using dst_vector_t = typename decltype(dst_tmp_vector)::type; + + // copy data from dst_tmp_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_tmp_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr 
index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_step_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + 
static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. 
vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong!"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert( + is_known_at_compile_time< + remove_cv_t>>::value && + is_known_at_compile_time>>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cv_t>{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, 
src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert{}(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp new file mode 100644 index 0000000000..ccac4b7b44 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -0,0 +1,779 @@ +#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP +#define 
CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! 
current implementation assume SrcData and DstData are same type"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 && + SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0, + "wrong!"); + }); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index 
forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector to buffer_ + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx); + + buffer_(Number{}) = + src_vector.template AsType()[Number{}]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + 
} + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // tensor descriptor for dst_vector + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(dst_vector_tensor_lengths, + DstVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + DstVectorTensorContiguousDimOrder{}); + + constexpr auto dst_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(dst_vector_tensor_lengths), + sequence_to_tuple_of_number(dst_vector_tensor_strides)); + + // dst access order and lengths + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-dst_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + vector_type_maker_t dst_vector; + + // copy data from buffer_ to dst_vector (also cast from SrcData to DstData) + static_ford{}([&](auto dst_vector_idx_) { + constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx); + + constexpr index_t dst_vector_offset = + dst_vector_desc.CalculateOffset(dst_vector_idx); + + dst_vector.template AsType()(Number{}) = + type_convert{}(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_vector)::type; + + // copy data from dst_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray 
move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_step_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_access_lengths = 
SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 
1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. 
SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert( + is_known_at_compile_time< + remove_cv_t>>::value && + is_known_at_compile_time>>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cv_t>{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_vector; + + using 
src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert{}( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp new file mode 100644 index 0000000000..affe096ace --- /dev/null +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -0,0 +1,801 @@ +#ifndef CK_XDLOPS_GEMM_HPP +#define CK_XDLOPS_GEMM_HPP + +#include "common_header.hpp" +#include "math.hpp" +#include "amd_xdlops.hpp" + +namespace ck { + +enum struct mfma_instr +{ + /// fp32 + mfma_f32_32x32x1xf32 = 0, + mfma_f32_16x16x1xf32, + mfma_f32_4x4x1xf32, + mfma_f32_32x32x2xf32, // k reduction + mfma_f32_16x16x4xf32, // k reduction + /// fp16 + mfma_f32_32x32x4f16, + mfma_f32_16x16x4f16, + mfma_f32_4x4x4f16, + 
mfma_f32_32x32x8f16, // k reduction + mfma_f32_16x16x16f16, // k reduction + /// bfp16 + mfma_f32_32x32x2bf16, + mfma_f32_16x16x2bf16, + mfma_f32_4x4x2bf16, + mfma_f32_32x32x4bf16, // k reduction + mfma_f32_16x16x8bf16, // k reduction +}; + +template +struct mfma_info; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 1; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 2; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static 
constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 4; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 1; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + } +}; + +// treat 4x4x1 as a single-blk 4x64 mfma +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr 
index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 1; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 4; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 8; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + } +}; + 
+template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 16; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 4; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr 
index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 4; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + } +}; + +#if 0 +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 2; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); + + return intrin_mfma_f32_32x32x2bf16::run( + p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 4; + static constexpr index_t cycles = 64; + static constexpr index_t 
k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); + + return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 8; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); + + return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 2; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = c_style_pointer_cast(a); + const auto 
p_b = c_style_pointer_cast(b); + + return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 2; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = c_style_pointer_cast(a); + const auto p_b = c_style_pointer_cast(b); + + return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); + } +}; +#endif + +template +struct xdlops_info +{ + static constexpr auto mfma_type = mfma_info{}; + + static constexpr index_t MPerXdlops = MPerXdlops_; + static constexpr index_t NPerXdlops = NPerXdlops_; + + static constexpr bool IsABroadcast() + { + static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); + return true; + } + + static constexpr bool IsKReduction() + { + return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); + } + + static constexpr index_t GetKPerXdlops() + { + return IsKReduction() ? 
mfma_type.num_input_blks : 1; + } + + static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } +}; + +template +struct XdlopsGemm +{ + template + static constexpr auto GetXdlopsInfo(); + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + +#if 0 + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + 
template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } +#endif + + using CIndex = MultiIndex<2>; + + __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + } + + __host__ __device__ constexpr XdlopsGemm() + { + static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 || + NPerXdlops == 64, + "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || + MPerXdlops == 64, + "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); + static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, + "m != num_input_blks * num_regs_blk"); + static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || + mfma_type.num_output_blks == 1, + "incorrect num_output_blks"); + static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n, + "num_regs_blk incorrect"); + + static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / mfma_type.wave_size; + } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + 
"base base_type must be float, half, ushort!"); + + static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); + + constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); + + static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { + constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); + constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); + + mfma_type.template run( + p_a_wave[Number{}], + p_b_wave[Number{}], + p_c_thread); + }); + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; + const index_t blk_id = laneId / mfma_type.num_threads_blk; + const index_t blk_td = laneId % mfma_type.num_threads_blk; + + index_t n_offset = blk_i * mfma_type.n + blk_td; + index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; + + return CIndex{m_offset, n_offset}; + } + + static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; + static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; + static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; + static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + + static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); + static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); + static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + + static constexpr auto GetBlkId(const index_t lane_id) + { + return lane_id / mfma_type.num_threads_blk; + } + + static constexpr auto GetBlkTd(const index_t lane_id) + { + return lane_id % mfma_type.num_threads_blk; + } + + static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; + + struct CLayout + { + __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } + __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } + __host__ __device__ static 
constexpr index_t N1() { return mfma_type.num_input_blks; } + __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } + + __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } + + __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + } + }; + + __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_address_space.hpp b/composable_kernel/include/utility/amd_address_space.hpp new file mode 100644 index 0000000000..24c95b27af --- /dev/null +++ b/composable_kernel/include/utility/amd_address_space.hpp @@ -0,0 +1,44 @@ +#ifndef CK_AMD_ADDRESS_SPACE_HPP +#define CK_AMD_ADDRESS_SPACE_HPP + +#include "config.hpp" +#include "c_style_pointer_cast.hpp" + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +namespace ck { + +enum AddressSpaceEnum_t +{ + Generic, + Global, + Lds, + Sgpr, + Vgpr, +}; + +template +__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) +{ + // cast a pointer in "Constant" address space (4) to "Generic" address space (0) + // only a C-style pointer cast seems to compile here, hence the warning suppression +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +template +__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p) +{ + // cast a pointer in "Generic" address space (0) to "Constant" address space (4) + // only a C-style pointer cast seems to compile here, hence the warning suppression +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T CONSTANT*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +} // namespace ck 
+#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp new file mode 100644 index 0000000000..57081b7fd7 --- /dev/null +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -0,0 +1,681 @@ +#ifndef CK_AMD_BUFFER_ADDRESSING_HPP +#define CK_AMD_BUFFER_ADDRESSING_HPP + +#include "data_type.hpp" + +namespace ck { + +template +union BufferResource +{ + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + StaticallyIndexedArray address; + StaticallyIndexedArray range; + StaticallyIndexedArray config; +}; + +template +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +// load +__device__ int8_t +llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); + +__device__ int8x2_t +llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); + +__device__ int8x4_t +llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); + +// the asm label's type suffix must match the declared return type (int16_t -> .i16) +__device__ int16_t +llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); +__device__ int32_t 
+llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +__device__ int32x2_t +llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +__device__ int32x4_t +llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); +// half +__device__ half_t +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ half2_t +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ half4_t +llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// float +__device__ float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ float2_t +llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +__device__ float4_t +llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// store +__device__ void +llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); + +__device__ void 
+llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +// half +__device__ void +llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); +// float +__device__ void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, + int32x4_t rsrc, + index_t voffset, + 
index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +template +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_fp16( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_fp16x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } 
+ else if constexpr(N == 8) + { +#if 0 + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(half_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + 0); + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( + 
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); +#else + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if constexpr(N == 8) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + else if constexpr(N == 16) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<2>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int8_t), + 0); + + tmp.AsType()(Number<3>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int8_t), + 0); + + return tmp.AsType()(Number<0>{}); +#else + int32x4_t tmp = 
llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif + } + } +} + +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if 
constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i16(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + else if constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#else + llvm_amdgcn_raw_buffer_store_i32(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); +#endif + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i32x2(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 16) + { + llvm_amdgcn_raw_buffer_store_i32x4(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + 
dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must be in global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff; + + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#else + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_element_valid ? tmp : vector_t(0); +#endif +} + +// buffer_load requires: +// 1) p_src_wave must be in global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. 
+template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_element_valid ? tmp : vector_t(customized_value); +} + +// buffer_store requires: +// 1) p_dst_wave must be global memory +// 2) p_dst_wave to be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x7fffffff; + + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp new file mode 100644 index 0000000000..a2d9d5f062 --- /dev/null +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -0,0 +1,356 @@ +#ifndef CK_AMD_INLINE_ASM_HPP +#define CK_AMD_INLINE_ASM_HPP + +#include "data_type.hpp" +#include "c_style_pointer_cast.hpp" + +// TODO: deprecate all amd_assembly_outer_product_xxx + +namespace ck { + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_fmac_f32 %0, %2, %3 \n \ + v_fmac_f32 %1, %2, %4 \n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4( + float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) +{ + asm volatile("\n \ + v_fmac_f32 %0, %4, %5 \n \ + v_fmac_f32 %1, %4, %6 \n \ + v_fmac_f32 %2, %4, %7 \n \ + v_fmac_f32 %3, %4, %8 \n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %3, %0\n \ + v_dot2_f32_f16 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + 
+// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %4, %0\n \ + v_dot2_f32_f16 %1, %2, %6, %1\n \ + v_dot2_f32_f16 %0, %3, %5, %0\n \ + v_dot2_f32_f16 %1, %3, %7, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "0"(c0), + "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half2_t a, + half2_t b0, + half2_t b1, + half2_t b2, + half2_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %5, %0\n \ + v_dot2_f32_f16 %1, %4, %6, %1\n \ + v_dot2_f32_f16 %2, %4, %7, %2\n \ + v_dot2_f32_f16 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half4_t a, + half4_t b0, + half4_t b1, + half4_t b2, + half4_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + const half2_t* p_b2_half2 = c_style_pointer_cast(&b2); + const half2_t* p_b3_half2 = c_style_pointer_cast(&b3); + + // do dot2 two times + asm volatile("\n 
\ + v_dot2_f32_f16 %0, %4, %6, %0\n \ + v_dot2_f32_f16 %1, %4, %8, %1\n \ + v_dot2_f32_f16 %2, %4, %10, %2\n \ + v_dot2_f32_f16 %3, %4, %12, %3\n \ + v_dot2_f32_f16 %0, %5, %7, %0\n \ + v_dot2_f32_f16 %1, %5, %9, %1\n \ + v_dot2_f32_f16 %2, %5, %11, %2\n \ + v_dot2_f32_f16 %3, %5, %13, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "v"(p_b2_half2[0]), + "v"(p_b2_half2[1]), + "v"(p_b3_half2[0]), + "v"(p_b3_half2[1]), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +} + +__device__ void amd_assembly_outer_product_1x4(half8_t a, + half8_t b0, + half8_t b1, + half8_t b2, + half8_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + + // TODO remove pointer casting + const half4_t* p_a_half4 = c_style_pointer_cast(&a); + const half4_t* p_b0_half4 = c_style_pointer_cast(&b0); + const half4_t* p_b1_half4 = c_style_pointer_cast(&b1); + const half4_t* p_b2_half4 = c_style_pointer_cast(&b2); + const half4_t* p_b3_half4 = c_style_pointer_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3); +} + +__device__ void amd_assembly_outer_product_1x4(half16_t a, + half16_t b0, + half16_t b1, + half16_t b2, + half16_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half8_t* p_a_half8 = c_style_pointer_cast(&a); + const half8_t* p_b0_half8 = c_style_pointer_cast(&b0); + const half8_t* p_b1_half8 = c_style_pointer_cast(&b1); + const half8_t* p_b2_half8 = c_style_pointer_cast(&b2); + const half8_t* p_b3_half8 = c_style_pointer_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); + + 
amd_assembly_outer_product_1x4( + p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %2, %3, %0\n \ + v_dot4_i32_i8 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(as_type(a)), + "v"(as_type(b0)), + "v"(as_type(b1)), + "0"(c0), + "1"(c1)); +#else + c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); +#endif +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(int8x4_t a, + int8x4_t b0, + int8x4_t b1, + int8x4_t b2, + int8x4_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %4, %5, %0\n \ + v_dot4_i32_i8 %1, %4, %6, %1\n \ + v_dot4_i32_i8 %2, %4, %7, %2\n \ + v_dot4_i32_i8 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(as_type(a)), + "v"(as_type(b0)), + "v"(as_type(b1)), + "v"(as_type(b2)), + "v"(as_type(b3)), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +#else + c0 = __builtin_amdgcn_sdot4(as_type(a), as_type(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(as_type(a), as_type(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(as_type(a), as_type(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(as_type(a), as_type(b3), c3, false); +#endif +} + +__device__ void amd_assembly_outer_product_1x4(int8x8_t a, + int8x8_t b0, + int8x8_t b1, + int8x8_t b2, + int8x8_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + 
vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); +} + +__device__ void amd_assembly_outer_product_1x4(int8x16_t a, + int8x16_t b0, + int8x16_t b1, + int8x16_t b2, + int8x16_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I2], + vector_type{b0}.AsType()[I2], + vector_type{b1}.AsType()[I2], + vector_type{b2}.AsType()[I2], + vector_type{b3}.AsType()[I2], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I3], + vector_type{b0}.AsType()[I3], + vector_type{b1}.AsType()[I3], + vector_type{b2}.AsType()[I3], + vector_type{b3}.AsType()[I3], + c0, + c1, + c2, + c3); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp new file mode 100644 index 0000000000..841d48f81c --- /dev/null +++ b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp @@ -0,0 +1,11 @@ +#ifndef CK_AMD_LLVM_INTRINSIC_HPP +#define CK_AMD_LLVM_INTRINSIC_HPP + +#include "data_type.hpp" + +namespace ck { + +__device__ int32_t 
llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp new file mode 100644 index 0000000000..da74fe1d48 --- /dev/null +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -0,0 +1,499 @@ +#ifndef CK_AMD_XDLOPS_HPP +#define CK_AMD_XDLOPS_HPP + +#include "data_type.hpp" + +namespace ck { + +// A, B, C, cbsz, abid, blgp +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); + +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + 
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); + +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( + ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( + ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( + ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( + ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( + ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); + +template +struct intrin_mfma_f32_32x32x1f32; + +template +struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template +struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> +{ + template + __device__ static void 
Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template +struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template +struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 2, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template +struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 1, + 0); + } +}; + +template +struct 
intrin_mfma_f32_32x32x4f16; + +template +struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x8f16; + +template +struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16f16; + +template +struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f16; + +template +struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + 
llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 2, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16; + +template +struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 1, + 0); + } +}; + +#if 0 +template +struct intrin_mfma_f32_32x32x2bf16; + +template +struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride> +{ + __device__ static c_vec32_4_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + reg_c.s.z = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0); + reg_c.s.w = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride> +{ + __device__ static c_vec32_4_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + { + reg_c.s.x = 
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + reg_c.s.z = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0); + reg_c.s.w = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride> +{ + __device__ static c_vec32_2_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride> +{ + __device__ static c_vec32_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride> +{ + __device__ static c_vec32_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + return reg_c; + } +}; + +__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); + return reg_c; +} + +__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec4_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); + return reg_c; +} + 
+template +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c); + +template <> +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); + return reg_c; +} + +template <> +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); + return reg_c; +} + +template +struct intrin_mfma_f32_4x4x2bf16; + +template <> +struct intrin_mfma_f32_4x4x2bf16<4, 64> +{ + __device__ static c_vec4_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); + return reg_c; + } +}; + +template <> +struct intrin_mfma_f32_4x4x2bf16<8, 64> +{ + __device__ static c_vec4_2_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); + return reg_c; + } +}; + +#endif + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp new file mode 100644 index 0000000000..7271094d39 --- /dev/null +++ b/composable_kernel/include/utility/array.hpp @@ -0,0 +1,63 @@ +#ifndef CK_ARRAY_HPP +#define CK_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +template +struct Array +{ + using type = Array; + using data_type = TData; + + TData mData[NSize]; + + __host__ __device__ static constexpr index_t 
Size() { return NSize; } + + __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; } + + __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; } + + __host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); } + + __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } +}; + +// empty Array +template +struct Array +{ + using type = Array; + using data_type = TData; + + __host__ __device__ static constexpr index_t Size() { return 0; } +}; + +template +__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) +{ + using data_type = remove_cv_t>; + return Array{{std::forward(x), std::forward(xs)...}}; +} + +// make empty array +template +__host__ __device__ constexpr auto make_array() +{ + return Array{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array_multi_index.hpp b/composable_kernel/include/utility/array_multi_index.hpp new file mode 100644 index 0000000000..f692fb5143 --- /dev/null +++ b/composable_kernel/include/utility/array_multi_index.hpp @@ -0,0 +1,77 @@ +#ifndef CK_ARRAY_MULTI_INDEX_HPP +#define CK_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... 
ys) { return make_multi_index(ys...); }, x); +} + +template +__host__ __device__ constexpr auto operator+=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! 
size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; }); + return r; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/c_style_pointer_cast.hpp b/composable_kernel/include/utility/c_style_pointer_cast.hpp new file mode 100644 index 0000000000..8acf5790c6 --- /dev/null +++ b/composable_kernel/include/utility/c_style_pointer_cast.hpp @@ -0,0 +1,22 @@ +#ifndef CK_C_STYLE_POINTER_CAST_HPP +#define CK_C_STYLE_POINTER_CAST_HPP + +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { + +template && is_pointer_v, bool>::type = false> +__host__ __device__ PY c_style_pointer_cast(PX p_x) +{ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wcast-align" + return (PY)p_x; // NOLINT(old-style-cast, cast-align) +#pragma clang diagnostic pop +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp new file mode 100644 index 0000000000..85c02a1b99 --- /dev/null +++ b/composable_kernel/include/utility/common_header.hpp @@ -0,0 +1,46 @@ +#ifndef CK_COMMON_HEADER_HPP +#define CK_COMMON_HEADER_HPP + +#include "config.hpp" +#include "array.hpp" +#include "container_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" +#include "multi_index.hpp" +#include "data_type.hpp" +#include "data_type_enum.hpp" +#include "data_type_enum_helper.hpp" +#include "functional.hpp" +#include "functional2.hpp" +#include "functional3.hpp" +#include "functional4.hpp" +#include "enable_if.hpp" +#include "integral_constant.hpp" +#include "math.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "synchronization.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "type.hpp" +#include "magic_division.hpp" +#include "utility.hpp" +#include "c_style_pointer_cast.hpp" +#include 
"amd_address_space.hpp" +#include "amd_buffer_addressing.hpp" +#include "static_buffer.hpp" +#include "dynamic_buffer.hpp" + +#include "inner_product.hpp" + +// TODO: remove this +#if CK_USE_AMD_INLINE_ASM +#include "amd_inline_asm.hpp" +#endif + +#if CK_USE_AMD_XDLOPS +#include "amd_xdlops.hpp" +#endif + +#endif diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp new file mode 100644 index 0000000000..521ad24d47 --- /dev/null +++ b/composable_kernel/include/utility/config.hpp @@ -0,0 +1,134 @@ +#ifndef CK_CONFIG_AMD_HPP +#define CK_CONFIG_AMD_HPP + +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif +#include "bfloat16_dev.hpp" + +// "Constant" address space for kernel parameter +#define CONSTANT __attribute__((address_space(4))) + +// GPU target +// should enable one and only one GPU target +#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ + defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030)) +#error Need to define (only) one GPU target +#endif + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 1 + +#ifdef CK_USE_LAUNCH_BOUNDS +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 2 +#endif + +// buffer resource +#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \ + defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(CK_AMD_GPU_GFX1030) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#endif + +// FMA instruction (target macro names are case-sensitive and must match the GPU target check above, e.g. CK_AMD_GPU_GFX90A) +#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) +#define CK_USE_AMD_V_MAC_F32 +#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || \ + defined(CK_AMD_GPU_GFX1030) +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8 +#endif + +// multi 
index +#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + +// AMD inline asm +#ifndef CK_USE_AMD_INLINE_ASM +#define CK_USE_AMD_INLINE_ASM 1 +#endif + +// AMD inner product (DLOP) +#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 +#endif + +// AMD buffer addressing +#ifndef CK_USE_AMD_BUFFER_ADDRESSING +#define CK_USE_AMD_BUFFER_ADDRESSING 1 +#endif + +// only gfx908 support native floating point atomic add +#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD +#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0 +#endif + +// AMD XDLOPS +#ifndef CK_USE_AMD_XDLOPS +#define CK_USE_AMD_XDLOPS 0 +#endif + +// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) +#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 +#endif + +// experimental implementation +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif + +#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 +#endif + +// pass tensor descriptor by value or void* +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 + +// merge transformation use magic number division +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 + +// hack: have underlying assumption that need to be satsified, otherwise it's a bug +// hack for forcing register to keep idx_diff_low_const in SGPR. 
idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +#endif + +// workaround for compiler crash when compiling recursive lambda +#ifndef CK_WORKAROUND_SWDEV_275126 +#define CK_WORKAROUND_SWDEV_275126 1 +#endif + +// workaround for compiler crash when using buffer load/store for i8 +#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 +#endif + +// workaround for compiler crash when using buffer load/store for i8 +#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 +#endif + +namespace ck { + +enum InMemoryDataOperationEnum_t +{ + Set, + AtomicAdd +}; + +// index type +using index_t = int32_t; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/container_element_picker.hpp b/composable_kernel/include/utility/container_element_picker.hpp new file mode 100644 index 0000000000..54915125ac --- /dev/null +++ b/composable_kernel/include/utility/container_element_picker.hpp @@ -0,0 +1,155 @@ +#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP +#define CK_CONTAINER_ELEMENT_PICKER_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ContainerElementPicker +{ + using type = ContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ContainerElementPicker() = delete; + + __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! 
exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr auto& At(Number i) + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray(IP); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + private: + Arr& mArray; +}; + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ConstantContainerElementPicker +{ + using type = ConstantContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ConstantContainerElementPicker() = delete; + + __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! 
exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + private: + const Arr& mArray; +}; + +template +__host__ __device__ constexpr auto operator+=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto operator-=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks) +{ + return ContainerElementPicker(a); +} + +template +__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks) +{ + return ConstantContainerElementPicker(a); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/container_helper.hpp b/composable_kernel/include/utility/container_helper.hpp new file mode 100644 index 0000000000..a7ed8ec059 --- /dev/null +++ b/composable_kernel/include/utility/container_helper.hpp @@ -0,0 +1,403 @@ +#ifndef CK_CONTAINER_HELPER_HPP +#define CK_CONTAINER_HELPER_HPP + +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "array.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto 
container_push_back(const Array& a, const TData& x) +{ + Array r; + + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + + r(Number{}) = x; + + return r; +} + +template +__host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) +{ + return container_concat(make_tuple(x), a); +} + +template +__host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) +{ + return container_concat(a, make_tuple(x)); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_new2old(const Array& old_array, Sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_array(old_array[Number{}]...); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_old2new(const Array& old_array, Sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple& old_tuple, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[Number{}]...); +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple& old_tuple, + Sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! 
invalid reorder map"); + + return Sequence::At(Number{})...>{}; +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence old_seq, + Sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_given_new2old(old_seq, new2old); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm-4.1 compiler would crash for recursive lambda +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + Number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, Number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, Number i, Number, Number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + Number{}, Number{}, Number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + if 
constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, Number{}, Number{}, Number{}); + } + else + { + return init; + } +} +#endif + +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Sequence& seq, Reduce f, Number) +{ + return reverse_exclusive_scan_sequence(seq, f, Number{}); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm4.1 compiler would crash with recursive lambda +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, Number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl( + const Tuple& x, Reduce reduce, Number i, YOld y_old, ROld r_old) +{ + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // 
recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, Number{}, make_tuple(init), init); +} +#endif + +// TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<> +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys) +{ + return container_concat(x, container_concat(ys...)); +} + +template +__host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); +} + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); +} + +template +__host__ __device__ constexpr auto container_concat(const Container& x) +{ + return x; +} + +template +__host__ __device__ constexpr auto get_container_subset(const Array& arr, Sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + return make_array(arr[Number{}]...); +} + +template +__host__ __device__ constexpr auto get_container_subset(const Tuple& tup, Sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! 
size"); + + return make_tuple(tup[Number{}]...); +} + +template +__host__ __device__ constexpr void +set_container_subset(Array& y, Sequence picks, const Array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr void +set_container_subset(Tuple& y, Sequence picks, const Tuple& x) +{ + static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr auto to_tuple_of_number(const Container&) +{ + static_assert(is_known_at_compile_time::value, "wrong!"); + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Container::At(i); + return Number{}; + }, + Container::Size()); +} + +template +__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) +{ + using Seq = Sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::At(i); + return Number{}; + }, + Seq::Size()); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp new file mode 100644 index 0000000000..24a2190e84 --- /dev/null +++ b/composable_kernel/include/utility/data_type.hpp @@ -0,0 +1,1017 @@ +#ifndef CK_FLOAT_TYPE_AMD_HPP +#define CK_FLOAT_TYPE_AMD_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +using half_t = _Float16; + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. 
The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +// scalar_type +template +struct scalar_type; + +template +struct scalar_type +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +// +template <> +struct scalar_type +{ + using type = float; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = half_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = ushort; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int32_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int8_t; + static constexpr index_t vector_size = 1; +}; + +// +template +struct vector_type +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : 
data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, "wrong!"); + + return data_.d1x1_; + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + 
} + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T 
d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; 
+ StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray 
d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t 
__attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + 
else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + 
return data_.d256x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } +}; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; +using half64_t = typename vector_type::type; + +// bfp16 +using ushort2_t = typename vector_type::type; +using ushort4_t = typename vector_type::type; +using ushort8_t = typename vector_type::type; +using ushort16_t = typename vector_type::type; +using ushort32_t = typename vector_type::type; +using ushort64_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = 
typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// data type conversion +template +struct type_convert +{ + template + __device__ T operator()(X x) const + { + return static_cast(x); + } +}; + +template <> +template <> +__device__ float type_convert::operator()(ushort x) const +{ + return bfloat16_to_float(x); +} + +template <> +template <> +__device__ ushort type_convert::operator()(float x) const +{ + return float_to_bfloat16(x); +} + +// TODO: deprecate this +template +struct inner_product_with_conversion +{ + static constexpr auto convert = type_convert(); + + template + __device__ T operator()(typename vector_type::type a, + typename vector_type::type b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, N, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); + + return acc; + } + + __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } + + __device__ T operator()(int8x4_t a, int8x4_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 4, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x8_t a, int8x8_t b) const + { + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + T acc = 0; + + static_for<0, 8, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } + + __device__ T operator()(int8x16_t a, int8x16_t b) const + { + const vector_type a_vector{a}; + const vector_type 
b_vector{b}; + + T acc = 0; + + static_for<0, 16, 1>{}([&](auto i) { + acc += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); + + return acc; + } +}; + +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Min() + { + return std::numeric_limits::min(); + } + + __host__ __device__ static constexpr int32_t Max() + { + return std::numeric_limits::max(); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type_enum.hpp b/composable_kernel/include/utility/data_type_enum.hpp new file mode 100644 index 0000000000..35df0067a9 --- /dev/null +++ b/composable_kernel/include/utility/data_type_enum.hpp @@ -0,0 +1,19 @@ +#ifndef CK_DATA_TYPE_ENUM_HPP +#define CK_DATA_TYPE_ENUM_HPP + +namespace ck { + +enum DataTypeEnum_t +{ + Half = 0, + Float = 1, + Int32 = 2, + Int8 = 3, + Int8x4 = 4, + BFloat16 = 5, + Double = 6, + Unknown = 100, +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/data_type_enum_helper.hpp b/composable_kernel/include/utility/data_type_enum_helper.hpp new file mode 100644 index 0000000000..451ce992b1 --- /dev/null +++ b/composable_kernel/include/utility/data_type_enum_helper.hpp @@ -0,0 +1,76 @@ +#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP +#define CK_DATA_TYPE_ENUM_HELPER_HPP + +#include "data_type.hpp" +#include "data_type_enum.hpp" + +namespace ck { + +template +struct get_datatype_from_enum; + +template <> +struct get_datatype_from_enum +{ + using type = int8_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = int32_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = half_t; +}; + +template <> +struct get_datatype_from_enum +{ + using type = float; +}; + +template <> +struct get_datatype_from_enum +{ + using type = double; +}; + +template +struct get_datatype_enum_from_type; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = 
DataTypeEnum_t::Int8; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float; +}; + +template <> +struct get_datatype_enum_from_type +{ + static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp new file mode 100644 index 0000000000..4d583e3ce7 --- /dev/null +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -0,0 +1,246 @@ +#ifndef CK_BUFFER_HPP +#define CK_BUFFER_HPP + +#include "amd_buffer_addressing.hpp" +#include "c_style_pointer_cast.hpp" +#include "enable_if.hpp" + +namespace ck { + +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + ElementSpaceSize element_space_size_; + T invalid_element_value_ = T{0}; + + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) + : p_data_{p_data}, element_space_size_{element_space_size} + { + } + + __host__ __device__ constexpr DynamicBuffer(T* p_data, + ElementSpaceSize element_space_size, + T invalid_element_value) + : p_data_{p_data}, + element_space_size_{element_space_size}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = + scalar_type>>::vector_size; + + constexpr index_t scalar_per_x_vector = + scalar_type>>::vector_size; + + 
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X need to be multiple T"); + +#if CK_USE_AMD_BUFFER_ADDRESSING + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_return_zero< + remove_cv_t>, + t_per_x>(p_data_, i, is_valid_element, element_space_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value< + remove_cv_t>, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); + } + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) : X{0}; + } + else + { + return is_valid_element ? *c_style_pointer_cast(&p_data_[i]) + : X{invalid_element_value_}; + } + } + } + + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = + scalar_type>>::vector_size; + + constexpr index_t scalar_per_x_vector = + scalar_type>>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X need to be multiple T"); + + if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global) + { +#if CK_USE_AMD_BUFFER_ADDRESSING + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store>, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); +#else + if(is_valid_element) + { + *c_style_pointer_cast(&p_data_[i]) = x; + } +#endif + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds) + { + if(is_valid_element) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + *c_style_pointer_cast(&p_data_[i]) = x; +#else + // HACK: compiler would lower IR "store address_space(3)" into + // inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + if constexpr(is_same>>::type, + int8_t>::value) + { + static_assert( + (is_same>, int8_t>::value && + is_same>, int8_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8x2_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8x4_t>::value) || + (is_same>, int8x4_t>::value && + is_same>, int8x4_t>::value) || + (is_same>, int8x8_t>::value && + is_same>, int8x8_t>::value) || + (is_same>, int8x16_t>::value && + is_same>, int8x16_t>::value), + "wrong! 
not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same>, int8_t>::value && + is_same>, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same>, int8_t>::value && + is_same>, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same>, int8_t>::value && + is_same>, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same>, + int8x4_t>::value && + is_same>, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same>, + int8x8_t>::value && + is_same>, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same>, + int8x16_t>::value && + is_same>, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + else + { + *c_style_pointer_cast(&p_data_[i]) = x; + } +#endif + } + } + else + { + if(is_valid_element) + { + *c_style_pointer_cast(&p_data_[i]) = x; + } + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) +{ + return DynamicBuffer{p, element_space_size}; +} + +template +__host__ 
__device__ constexpr auto +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value) +{ + return DynamicBuffer{ + p, element_space_size, invalid_element_value}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/enable_if.hpp b/composable_kernel/include/utility/enable_if.hpp new file mode 100644 index 0000000000..501e1bfc1c --- /dev/null +++ b/composable_kernel/include/utility/enable_if.hpp @@ -0,0 +1,13 @@ +#ifndef CK_ENABLE_IF_HPP +#define CK_ENABLE_IF_HPP + +namespace ck { + +template +using enable_if = std::enable_if; + +template +using enable_if_t = typename std::enable_if::type; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional.hpp b/composable_kernel/include/utility/functional.hpp new file mode 100644 index 0000000000..b84b617f44 --- /dev/null +++ b/composable_kernel/include/utility/functional.hpp @@ -0,0 +1,116 @@ +#ifndef CK_FUNCTIONAL_HPP +#define CK_FUNCTIONAL_HPP + +#include "integral_constant.hpp" +#include "type.hpp" + +namespace ck { + +// TODO: right? wrong? +struct forwarder +{ + template + __host__ __device__ constexpr T&& operator()(T&& x) const + { + return static_cast(x); + } +}; + +struct swallow +{ + template + __host__ __device__ constexpr swallow(Ts&&...) 
+ { + } +}; + +template +struct logical_and +{ + constexpr bool operator()(const T& x, const T& y) const { return x && y; } +}; + +template +struct logical_or +{ + constexpr bool operator()(const T& x, const T& y) const { return x || y; } +}; + +template +struct logical_not +{ + constexpr bool operator()(const T& x) const { return !x; } +}; + +// Emulate if constexpr +template +struct static_if; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F f) const + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + return Type{}; + } + + template + __host__ __device__ static void Else(F) + { + } +}; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F) const + { + return Type{}; + } + + template + __host__ __device__ static void Else(F f) + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + } +}; + +template +struct conditional; + +template +struct conditional +{ + using type = X; +}; + +template +struct conditional +{ + using type = Y; +}; + +template +using conditional_t = typename conditional::type; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp new file mode 100644 index 0000000000..371182a05e --- /dev/null +++ b/composable_kernel/include/utility/functional2.hpp @@ -0,0 +1,48 @@ +#ifndef CK_FUNCTIONAL2_HPP +#define CK_FUNCTIONAL2_HPP + +#include "functional.hpp" +#include "sequence.hpp" + +namespace ck { + +namespace detail 
{ + +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + swallow{(f(Number{}), 0)...}; + } +}; + +} // namespace detail + +// F signature: F(Number) +template +struct static_for +{ + __host__ __device__ constexpr static_for() + { + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); + } + + template + __host__ __device__ constexpr void operator()(F f) const + { + detail::static_for_impl::type>{}( + f); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional3.hpp b/composable_kernel/include/utility/functional3.hpp new file mode 100644 index 0000000000..6a400f3ca6 --- /dev/null +++ b/composable_kernel/include/utility/functional3.hpp @@ -0,0 +1,142 @@ +#ifndef CK_FUNCTIONAL3_HPP +#define CK_FUNCTIONAL3_HPP + +#include "functional.hpp" +#include "functional2.hpp" +#include "sequence.hpp" +#include "multi_index.hpp" + +namespace ck { + +namespace detail { + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct static_ford_impl +{ + __host__ __device__ constexpr static_ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! 
should not get here"); + } + + // F signature: F(Sequence<...>) + // CurrentOrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const + { + static_for<0, RemainLengths::Front(), 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentOrderedId::PushBack(I)); + }); + } +}; + +template +struct static_ford_impl, Orders> +{ + // F signature: F(Sequence<...>) + // OrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, OrderedId) const + { + // retrive unordered Id + f(OrderedId::ReorderGivenOld2New(Orders{})); + } +}; + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct ford_impl +{ + __host__ __device__ constexpr ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + } + + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + for(index_t i = 0; i < RemainLengths::Front(); ++i) + { + ford_impl{}( + f, container_push_back(current_ordered_id, i)); + } + } +}; + +template +struct ford_impl, Orders> +{ + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + // retrive unordered Id + f(container_reorder_given_old2new(current_ordered_id, Orders{})); + } +}; + +} // namespace detail + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which static_ford +// will loop over each +// dimension +template ::type> +struct static_ford +{ + __host__ __device__ constexpr static_ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! 
inconsistent size"); + } + + // F signature: F(Sequence<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + detail::static_ford_impl{}(f, Sequence<>{}); + } +}; + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which ford will loop +// over each +// dimension +template ::type> +struct ford +{ + __host__ __device__ constexpr ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); + } + + // F signature: F(Array<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + + for(index_t i = 0; i < ordered_lengths.Front(); ++i) + { + detail::ford_impl{}(f, + make_multi_index(i)); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/functional4.hpp b/composable_kernel/include/utility/functional4.hpp new file mode 100644 index 0000000000..b039644380 --- /dev/null +++ b/composable_kernel/include/utility/functional4.hpp @@ -0,0 +1,62 @@ +#ifndef CK_FUNCTIONAL4_HPP +#define CK_FUNCTIONAL4_HPP + +#include "sequence.hpp" +#include "tuple.hpp" +#include "array.hpp" + +namespace ck { + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x) const + { + return std::forward(f)(std::forward(x).At(Number{})...); + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, Sequence> +{ + template + __host__ __device__ constexpr 
auto operator()(F&& f, X&& x, Y&& y) const + { + return std::forward(f)(std::forward(x).At(Number{})..., + std::forward(y).At(Number{})...); + } +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto unpack(F&& f, X&& x) +{ + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + std::forward(f), std::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( + std::forward(f), std::forward(x), std::forward(y)); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/inner_product.hpp b/composable_kernel/include/utility/inner_product.hpp new file mode 100644 index 0000000000..51753accf3 --- /dev/null +++ b/composable_kernel/include/utility/inner_product.hpp @@ -0,0 +1,207 @@ +#ifndef CK_INNER_PRODUCT_HPP +#define CK_INNER_PRODUCT_HPP + +#include "data_type.hpp" + +namespace ck { + +template +__device__ void inner_product(const TA& a, const TB& b, TC& c); + +template <> +__device__ void inner_product(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) + asm volatile("\n \ + v_mac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c += a * b; +#endif +} + +template <> +__device__ void +inner_product(const float2_t& a, const float2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + 
c); +} + +template <> +__device__ void +inner_product(const float4_t& a, const float4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void inner_product(const half2_t& a, const half2_t& b, float& c) +{ +#if defined(CK_USE_AMD_V_DOT2_F32_F16) +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + asm volatile("\n \ + v_dot2_f32_f16 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c = __builtin_amdgcn_sdot2(a, b, c, false); +#endif +#else + const auto convert = type_convert{}; + + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 2, 1>{}([&](auto i) { + c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void inner_product(const half4_t& a, const half4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void inner_product(const half8_t& a, const half8_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + 
inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void +inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) +{ +#if defined(CK_USE_DOT4_I32_I8) +#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); +#endif +#else + const auto convert = type_convert{}; + + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 4, 1>{}([&](auto i) { + c += convert(a_vector.AsType()[i]) * convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void +inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/integral_constant.hpp b/composable_kernel/include/utility/integral_constant.hpp new file mode 100644 index 0000000000..14f3df894b --- /dev/null +++ b/composable_kernel/include/utility/integral_constant.hpp @@ -0,0 +1,17 @@ +#ifndef CK_INTEGRAL_CONSTANT_HPP +#define CK_INTEGRAL_CONSTANT_HPP + +namespace 
ck { + +template +struct integral_constant +{ + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp new file mode 100644 index 0000000000..b7489016e9 --- /dev/null +++ b/composable_kernel/include/utility/magic_division.hpp @@ -0,0 +1,155 @@ +#ifndef CK_MAGIC_DIVISION_HPP +#define CK_MAGIC_DIVISION_HPP + +#include "config.hpp" +#include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" +#include "tuple.hpp" + +namespace ck { + +// magic number division +// Caution: +// 1. For uint32_t as dividend: magic number division implementation being used would produce +// correct result if the dividend is uint32_t and its value is within 31-bit value range. +// 2. For int32_t as dividend: magic number division for int32_t dividend has not been +// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number +// division implementation for uint32_t is then used. Therefore, dividend value needs to be +// non-negative. +// TODO: +// 1. Implement magic number division for int32_t +// 2. 
Implement magic number division for uint32_t with 32-bit value range +struct MagicDivision +{ + // uint32_t + __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) + { + // assert(divisor >= 1 && divisor <= INT32_MAX); + uint32_t shift = 0; + for(shift = 0; shift < 32; ++shift) + { + if((1U << shift) >= divisor) + { + break; + } + } + + uint64_t one = 1; + uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1; + // assert(multiplier <= 0xffffffffUL); + + return make_tuple(uint32_t(multiplier), shift); + } + + __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) + { + auto tmp = CalculateMagicNumbers(divisor); + + return tmp[Number<0>{}]; + } + + __host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor) + { + auto tmp = CalculateMagicNumbers(divisor); + + return tmp[Number<1>{}]; + } + + // integral_constant + template + __host__ __device__ static constexpr auto + CalculateMagicNumbers(integral_constant) + { + constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[Number<0>{}]; + constexpr uint32_t shift = tmp[Number<1>{}]; + + return make_tuple(integral_constant{}, + integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicMultiplier(integral_constant) + { + constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); + + return integral_constant{}; + } + + template + __host__ __device__ static constexpr auto + CalculateMagicShift(integral_constant) + { + constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); + + return integral_constant{}; + } + + // integral_constant + template + __host__ __device__ static constexpr auto + CalculateMagicNumbers(integral_constant) + { + return CalculateMagicNumbers(integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicMultiplier(integral_constant) + { + 
return CalculateMagicMultiplier(integral_constant{}); + } + + template + __host__ __device__ static constexpr auto + CalculateMagicShift(integral_constant) + { + return CalculateMagicShift(integral_constant{}); + } + + // magic division for uint32_t + __host__ __device__ static constexpr uint32_t + DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; + return (tmp + dividend) >> shift; + } + +#if 1 // debug + // HACK: magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 needs to be + // non-negative for result to be correct + // TODO: figure out how to do magic number division for int32_t as dividend + __host__ __device__ static constexpr int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = as_type(dividend_i32); + uint32_t tmp = + (static_cast(dividend_u32) * static_cast(multiplier)) >> 32; + return (tmp + dividend_u32) >> shift; + } +#else + // the inline ASM is producing wrong result + __host__ __device__ static int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t r; + asm volatile("\n \ + v_mul_hi_u32 %0, %1, %2 \n \ + v_add_u32_e32 %0, %1, %0 \n \ + v_lshrrev_b32_e32 %0, %3, %0 \n \ + " + : "=v"(r) + : "v"(as_type(dividend_i32)), "s"(multiplier), "s"(shift)); + + return as_type(r); + } +#endif +}; + +} // namespace ck + +#endif diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp new file mode 100644 index 0000000000..48438e6179 --- /dev/null +++ b/composable_kernel/include/utility/math.hpp @@ -0,0 +1,216 @@ +#ifndef CK_MATH_HPP +#define CK_MATH_HPP + +#include "config.hpp" +#include "integral_constant.hpp" +#include "number.hpp" +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { +namespace math { + +template +struct scales +{ + __host__ __device__ constexpr T 
operator()(T a) const { return s * a; } +}; + +template +struct plus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } +}; + +template +struct minus +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; } +}; + +struct multiplies +{ + template + __host__ __device__ constexpr auto operator()(const A& a, const B& b) const + { + return a * b; + } +}; + +template +struct maximize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } +}; + +template +struct minimize +{ + __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; } +}; + +template +struct integer_divide_ceiler +{ + __host__ __device__ constexpr T operator()(T a, T b) const + { + static_assert(is_same{} || is_same{}, "wrong type"); + + return (a + b - Number<1>{}) / b; + } +}; + +template +__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) +{ + return x / y; +} + +template +__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) +{ + return (x + y - Number<1>{}) / y; +} + +template +__host__ __device__ constexpr auto integer_least_multiple(X x, Y y) +{ + return y * integer_divide_ceil(x, y); +} + +template +__host__ __device__ constexpr T max(T x) +{ + return x; +} + +template +__host__ __device__ constexpr T max(T x, T y) +{ + return x > y ? x : y; +} + +template +__host__ __device__ constexpr index_t max(Number, index_t y) +{ + return X > y ? X : y; +} + +template +__host__ __device__ constexpr index_t max(index_t x, Number) +{ + return x > Y ? x : Y; +} + +template +__host__ __device__ constexpr auto max(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return max(x, max(ys...)); +} + +template +__host__ __device__ constexpr T min(T x) +{ + return x; +} + +template +__host__ __device__ constexpr T min(T x, T y) +{ + return x < y ? 
x : y; +} + +template +__host__ __device__ constexpr index_t min(Number, index_t y) +{ + return X < y ? X : y; +} + +template +__host__ __device__ constexpr index_t min(index_t x, Number) +{ + return x < Y ? x : Y; +} + +template +__host__ __device__ constexpr auto min(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return min(x, min(ys...)); +} + +// greatest common divisor, aka highest common factor +__host__ __device__ constexpr index_t gcd(index_t x, index_t y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template +__host__ __device__ constexpr auto gcd(Number, Number) +{ + constexpr auto r = gcd(X, Y); + + return Number{}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto gcd(X x, Ys... ys) +{ + return gcd(x, gcd(ys...)); +} + +// least common multiple +template +__host__ __device__ constexpr auto lcm(X x, Y y) +{ + return (x * y) / gcd(x, y); +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto lcm(X x, Ys... 
ys) +{ + return lcm(x, lcm(ys...)); +} + +template +struct equal +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } +}; + +template +struct less +{ + __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } +}; + +} // namespace math +} // namespace ck + +#endif diff --git a/composable_kernel/include/utility/multi_index.hpp b/composable_kernel/include/utility/multi_index.hpp new file mode 100644 index 0000000000..0bb34fb1e2 --- /dev/null +++ b/composable_kernel/include/utility/multi_index.hpp @@ -0,0 +1,12 @@ +#ifndef CK_MULTI_INDEX_HPP +#define CK_MULTI_INDEX_HPP + +#include "common_header.hpp" + +#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX +#include "array_multi_index.hpp" +#else +#include "statically_indexed_array_multi_index.hpp" +#endif + +#endif diff --git a/composable_kernel/include/utility/number.hpp b/composable_kernel/include/utility/number.hpp new file mode 100644 index 0000000000..f8c5643694 --- /dev/null +++ b/composable_kernel/include/utility/number.hpp @@ -0,0 +1,44 @@ +#ifndef CK_NUMBER_HPP +#define CK_NUMBER_HPP + +#include "integral_constant.hpp" + +namespace ck { + +template +using Number = integral_constant; + +template +__host__ __device__ constexpr auto operator+(Number, Number) +{ + return Number{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Number) +{ + static_assert(Y <= X, "wrong!"); + return Number{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Number) +{ + return Number{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Number) +{ + static_assert(Y > 0, "wrong!"); + return Number{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Number) +{ + static_assert(Y > 0, "wrong!"); + return Number{}; +} +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/print.hpp b/composable_kernel/include/utility/print.hpp new file mode 100644 index 0000000000..d7d58bbb83 --- /dev/null +++ 
b/composable_kernel/include/utility/print.hpp @@ -0,0 +1,22 @@ +#ifndef CK_PRINT_HPP +#define CK_PRINT_HPP + +#include "array.hpp" +#include "statically_indexed_array.hpp" +#include "container_helper.hpp" +#include "sequence.hpp" + +namespace ck { + +template +__host__ __device__ void print_array(const char* s, T a) +{ + constexpr index_t nsize = a.Size(); + + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); + printf("}\n"); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/sequence.hpp b/composable_kernel/include/utility/sequence.hpp new file mode 100644 index 0000000000..b35999d56f --- /dev/null +++ b/composable_kernel/include/utility/sequence.hpp @@ -0,0 +1,880 @@ +#ifndef CK_SEQUENCE_HPP +#define CK_SEQUENCE_HPP + +#include "integral_constant.hpp" +#include "type.hpp" +#include "functional.hpp" +#include "math.hpp" + +namespace ck { + +template +struct static_for; + +template +struct Sequence; + +template +struct sequence_split; + +template +struct sequence_reverse; + +template +struct sequence_map_inverse; + +template +struct is_valid_sequence_map; + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence); + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq); + +template +struct Sequence +{ + using Type = Sequence; + using data_type = index_t; + + static constexpr index_t mSize = sizeof...(Is); + + __host__ __device__ static constexpr auto Size() { return Number{}; } + + __host__ __device__ static constexpr auto GetSize() { return Size(); } + + __host__ __device__ static constexpr index_t At(index_t I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const index_t mData[mSize + 1] = {Is..., 0}; + return mData[I]; + } + + template + __host__ __device__ static constexpr auto At(Number) + { + static_assert(I < mSize, "wrong! 
I too large"); + + return Number{}; + } + + template + __host__ __device__ static constexpr auto Get(Number) + { + return At(Number{}); + } + + template + __host__ __device__ constexpr auto operator[](I i) const + { + return At(i); + } + + template + __host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence /*new2old*/) + { + static_assert(sizeof...(Is) == sizeof...(IRs), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); + + return Sequence{})...>{}; + } + + // MapOld2New is Sequence<...> + template + __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) + { + static_assert(MapOld2New::Size() == Size(), + "wrong! reorder map should have the same size as Sequence to be rerodered"); + + static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); + + return ReorderGivenNew2Old(typename sequence_map_inverse::type{}); + } + + __host__ __device__ static constexpr auto Reverse() + { + return typename sequence_reverse::type{}; + } + + __host__ __device__ static constexpr auto Front() + { + static_assert(mSize > 0, "wrong!"); + return At(Number<0>{}); + } + + __host__ __device__ static constexpr auto Back() + { + static_assert(mSize > 0, "wrong!"); + return At(Number{}); + } + + __host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); } + + __host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); } + + template + __host__ __device__ static constexpr auto PushFront(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushFront(Number...) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Sequence) + { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto PushBack(Number...) 
+ { + return Sequence{}; + } + + template + __host__ __device__ static constexpr auto Extract(Number...) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto Extract(Sequence) + { + return Sequence{})...>{}; + } + + template + __host__ __device__ static constexpr auto Modify(Number, Number) + { + static_assert(I < Size(), "wrong!"); + + using seq_split = sequence_split; + constexpr auto seq_left = typename seq_split::left_type{}; + constexpr auto seq_right = typename seq_split::right_type{}.PopFront(); + + return seq_left.PushBack(Number{}).PushBack(seq_right); + } + + template + __host__ __device__ static constexpr auto Transform(F f) + { + return Sequence{}; + } + + __host__ __device__ static void Print() + { + printf("{"); + printf("size %d, ", index_t{Size()}); + static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); }); + printf("}"); + } +}; + +// merge sequence +template +struct sequence_merge +{ + using type = typename sequence_merge::type>::type; +}; + +template +struct sequence_merge, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_merge +{ + using type = Seq; +}; + +// generate sequence +template +struct sequence_gen +{ + template + struct sequence_gen_impl + { + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = typename sequence_merge< + typename sequence_gen_impl::type, + typename sequence_gen_impl::type>::type; + }; + + template + struct sequence_gen_impl + { + static constexpr index_t Is = G{}(Number{}); + using type = Sequence; + }; + + template + struct sequence_gen_impl + { + using type = Sequence<>; + }; + + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +// arithmetic sequence +template +struct arithmetic_sequence_gen +{ + struct F + { + __host__ __device__ constexpr index_t operator()(index_t i) const + { + 
return i * Increment + IBegin; + } + }; + + using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; +}; + +// uniform sequence +template +struct uniform_sequence_gen +{ + struct F + { + __host__ __device__ constexpr index_t operator()(index_t) const { return I; } + }; + + using type = typename sequence_gen::type; +}; + +// reverse inclusive scan (with init) sequence +template +struct sequence_reverse_inclusive_scan; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; + + static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front()); + + using type = typename sequence_merge, old_scan>::type; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = Sequence; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = Sequence<>; +}; + +// split sequence +template +struct sequence_split +{ + static constexpr index_t NSize = Seq{}.Size(); + + using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; + using range1 = typename arithmetic_sequence_gen::type; + + using left_type = decltype(Seq::Extract(range0{})); + using right_type = decltype(Seq::Extract(range1{})); +}; + +// reverse sequence +template +struct sequence_reverse +{ + static constexpr index_t NSize = Seq{}.Size(); + + using seq_split = sequence_split; + using type = typename sequence_merge< + typename sequence_reverse::type, + typename sequence_reverse::type>::type; +}; + +template +struct sequence_reverse> +{ + using type = Sequence; +}; + +template +struct sequence_reverse> +{ + using type = Sequence; +}; + +#if 1 +template +struct sequence_reduce +{ + using type = typename sequence_reduce::type>::type; +}; + +template +struct sequence_reduce, Sequence> +{ + using type = Sequence; +}; + +template +struct sequence_reduce +{ + using type = Seq; +}; +#endif + +template +struct sequence_sort_impl +{ + template + struct 
sorted_sequence_merge_impl + { + static constexpr bool choose_left = LeftValues::Front() < RightValues::Front(); + + static constexpr index_t chosen_value = + choose_left ? LeftValues::Front() : RightValues::Front(); + static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front(); + + using new_merged_values = decltype(MergedValues::PushBack(Number{})); + using new_merged_ids = decltype(MergedIds::PushBack(Number{})); + + using new_left_values = + typename conditional::type; + using new_left_ids = + typename conditional::type; + + using new_right_values = + typename conditional::type; + using new_right_ids = + typename conditional::type; + + using merge = sorted_sequence_merge_impl; + // this is output + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, + Sequence<>, + RightValues, + RightIds, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using merge = sorted_sequence_merge_impl, + Sequence<>, + Comp>; + + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + static constexpr index_t nsize = Values::Size(); + + using split_unsorted_values = sequence_split; + using split_unsorted_ids = sequence_split; + + using left_unsorted_values = typename split_unsorted_values::left_type; + using left_unsorted_ids = typename split_unsorted_ids::left_type; + using left_sort = sequence_sort_impl; + using left_sorted_values = typename left_sort::sorted_values; + using left_sorted_ids = typename left_sort::sorted_ids; + + using 
right_unsorted_values = typename split_unsorted_values::right_type; + using right_unsorted_ids = typename split_unsorted_ids::right_type; + using right_sort = sequence_sort_impl; + using right_sorted_values = typename right_sort::sorted_values; + using right_sorted_ids = typename right_sort::sorted_ids; + + using merged_sorted = sorted_sequence_merge; + + using sorted_values = typename merged_sorted::merged_values; + using sorted_ids = typename merged_sorted::merged_ids; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + static constexpr bool choose_x = Compare{}(ValueX, ValueY); + + using sorted_values = + typename conditional, Sequence>::type; + using sorted_ids = typename conditional, Sequence>::type; +}; + +template +struct sequence_sort_impl, Sequence, Compare> +{ + using sorted_values = Sequence; + using sorted_ids = Sequence; +}; + +template +struct sequence_sort_impl, Sequence<>, Compare> +{ + using sorted_values = Sequence<>; + using sorted_ids = Sequence<>; +}; + +template +struct sequence_sort +{ + using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type; + using sort = sequence_sort_impl; + + // this is output + using type = typename sort::sorted_values; + using sorted2unsorted_map = typename sort::sorted_ids; +}; + +template +struct sequence_unique_sort +{ + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t current_value = RemainValues::Front(); + static constexpr index_t current_id = RemainIds::Front(); + + static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back()); + + using new_remain_values = decltype(RemainValues::PopFront()); + using new_remain_ids = decltype(RemainIds::PopFront()); + + using new_uniquified_values = + typename conditional{})), + UniquifiedValues>::type; + + using new_uniquified_ids = + typename conditional{})), + UniquifiedIds>::type; + + using uniquify = sorted_sequence_uniquify_impl; + + // this is output + using uniquified_values = 
typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + template + struct sorted_sequence_uniquify_impl, + Sequence<>, + UniquifiedValues, + UniquifiedIds, + Eq> + { + using uniquified_values = UniquifiedValues; + using uniquified_ids = UniquifiedIds; + }; + + template + struct sorted_sequence_uniquify + { + using uniquify = sorted_sequence_uniquify_impl, + Sequence, + Eq>; + + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + using sort = sequence_sort; + using sorted_values = typename sort::type; + using sorted_ids = typename sort::sorted2unsorted_map; + + using uniquify = sorted_sequence_uniquify; + + // this is output + using type = typename uniquify::uniquified_values; + using sorted2unsorted_map = typename uniquify::uniquified_ids; +}; + +template +struct is_valid_sequence_map : is_same::type, + typename sequence_sort>::type> +{ +}; + +template +struct sequence_map_inverse +{ + template + struct sequence_map_inverse_impl + { + static constexpr auto new_y2x = + WorkingY2X::Modify(X2Y::At(Number{}), Number{}); + + using type = + typename sequence_map_inverse_impl:: + type; + }; + + template + struct sequence_map_inverse_impl + { + using type = WorkingY2X; + }; + + using type = + typename sequence_map_inverse_impl::type, + 0, + SeqMap::Size()>::type; +}; + +template +__host__ __device__ constexpr auto operator+(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs + Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs - Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! 
inconsistent size"); + + return Sequence<(Xs * Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs / Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return Sequence<(Xs % Ys)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Sequence, Number) +{ + return Sequence<(Xs + Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Sequence, Number) +{ + return Sequence<(Xs - Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Sequence, Number) +{ + return Sequence<(Xs * Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Sequence, Number) +{ + return Sequence<(Xs / Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Sequence, Number) +{ + return Sequence<(Xs % Y)...>{}; +} + +template +__host__ __device__ constexpr auto operator+(Number, Sequence) +{ + return Sequence<(Y + Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator-(Number, Sequence) +{ + return Sequence<(Y - Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator*(Number, Sequence) +{ + return Sequence<(Y * Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator/(Number, Sequence) +{ + return Sequence<(Y / Xs)...>{}; +} + +template +__host__ __device__ constexpr auto operator%(Number, Sequence) +{ + return Sequence<(Y % Xs)...>{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_front(Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_pop_back(Seq) +{ + static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!"); + return sequence_pop_front(Seq::Reverse()).Reverse(); +} + +template +__host__ __device__ constexpr auto merge_sequences(Seqs...) 
+{ + return typename sequence_merge::type{}; +} + +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence) +{ + return Sequence{}; +} + +template +__host__ __device__ constexpr auto transform_sequences(F f, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize, "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto +transform_sequences(F f, Sequence, Sequence, Sequence) +{ + static_assert(Sequence::mSize == Sequence::mSize && + Sequence::mSize == Sequence::mSize, + "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number) +{ + return typename sequence_reverse_inclusive_scan::type{}; +} + +template +__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number{}) + .PushBack(Number{}); +} + +template +__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number{}).Reverse(); +} + +template +__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence /* ids */) +{ + return Sequence{})...>{}; +} + +#if 1 +namespace detail { +template +struct pick_sequence_elements_by_mask_impl +{ + using new_work_seq = typename conditional::type; + + using type = + typename pick_sequence_elements_by_mask_impl::type; +}; + +template +struct pick_sequence_elements_by_mask_impl, Sequence<>> +{ + using type = WorkSeq; +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask) +{ + static_assert(Seq::Size() == Mask::Size(), "wrong!"); + + return typename detail::pick_sequence_elements_by_mask_impl, Seq, Mask>::type{}; +} + +namespace detail { +template +struct modify_sequence_elements_by_ids_impl +{ + using new_work_seq = 
decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front())); + + using type = + typename modify_sequence_elements_by_ids_impl::type; +}; + +template +struct modify_sequence_elements_by_ids_impl, Sequence<>> +{ + using type = WorkSeq; +}; +} // namespace detail + +template +__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids) +{ + static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!"); + + return typename detail::modify_sequence_elements_by_ids_impl::type{}; +} +#endif + +template +__host__ __device__ constexpr index_t +reduce_on_sequence(Seq, Reduce f, Number /*initial_value*/) +{ + index_t result = Init; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + result = f(result, Seq::At(i)); + } + + return result; +} + +// TODO: a generic any_of for any container +template +__host__ __device__ constexpr bool sequence_any_of(Seq, F f) +{ + bool flag = false; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + flag = flag || f(Seq::At(i)); + } + + return flag; +} + +// TODO: a generic all_of for any container +template +__host__ __device__ constexpr bool sequence_all_of(Seq, F f) +{ + bool flag = true; + + for(index_t i = 0; i < Seq::Size(); ++i) + { + flag = flag && f(Seq::At(i)); + } + + return flag; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp new file mode 100644 index 0000000000..88d7da63e8 --- /dev/null +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -0,0 +1,36 @@ +#ifndef CK_SEQUENCE_HELPER_HPP +#define CK_SEQUENCE_HELPER_HPP + +#include "tuple.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_sequence(Number...) 
+{ + return Sequence{}; +} + +// F returns index_t +template +__host__ __device__ constexpr auto generate_sequence(F, Number) +{ + return typename sequence_gen::type{}; +} + +// F returns Number<> +template +__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +__host__ __device__ constexpr auto to_sequence(Tuple...>) +{ + return Sequence{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp new file mode 100644 index 0000000000..cd67b8a0be --- /dev/null +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -0,0 +1,71 @@ +#ifndef CK_STATIC_BUFFER_HPP +#define CK_STATIC_BUFFER_HPP + +#include "statically_indexed_array.hpp" + +namespace ck { + +template +struct StaticBuffer : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + T invalid_element_value_ = T{0}; + + __host__ __device__ constexpr StaticBuffer() : base{} {} + + __host__ __device__ constexpr StaticBuffer(T invalid_element_value) + : base{}, invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + template + __host__ __device__ constexpr auto Get(Number i, bool is_valid_element) const + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? At(i) : T{0}; + } + else + { + return is_valid_element ? 
At(i) : invalid_element_value_; + } + } + + template + __host__ __device__ void Set(Number i, bool is_valid_element, const T& x) + { + if(is_valid_element) + { + At(i) = x; + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } +}; + +template +__host__ __device__ constexpr auto make_static_buffer(Number) +{ + return StaticBuffer{}; +} + +template +__host__ __device__ constexpr auto make_static_buffer(Number, T invalid_element_value) +{ + return StaticBuffer{invalid_element_value}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array.hpp b/composable_kernel/include/utility/statically_indexed_array.hpp new file mode 100644 index 0000000000..f30a3a9ee6 --- /dev/null +++ b/composable_kernel/include/utility/statically_indexed_array.hpp @@ -0,0 +1,40 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP +#define CK_STATICALLY_INDEXED_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +namespace detail { + +template +__host__ __device__ constexpr auto generate_same_type_tuple() +{ + return generate_tuple([](auto) -> T { return T{}; }, Number{}); +} + +template +using same_type_tuple = decltype(generate_same_type_tuple()); + +} // namespace detail + +template +using StaticallyIndexedArray = detail::same_type_tuple; + +template +__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... 
xs) +{ + return StaticallyIndexedArray(x, static_cast(xs)...); +} + +// make empty StaticallyIndexedArray +template +__host__ __device__ constexpr auto make_statically_indexed_array() +{ + return StaticallyIndexedArray(); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp b/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp new file mode 100644 index 0000000000..9e96f06d73 --- /dev/null +++ b/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp @@ -0,0 +1,108 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP +#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = StaticallyIndexedArray; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_statically_indexed_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +// Here should use MultiIndex, instead of Tuple, although the former +// is the alias of the latter. This is because compiler cannot infer the NSize if +// using MultiIndex +// TODO: how to fix this? +template +__host__ __device__ constexpr auto operator+=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! 
size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; }); + return r; +} + +// MultiIndex = index_t * MultiIndex +template +__host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; }); + return r; +} + +template +__host__ __device__ void print_multi_index(const Tuple& x) +{ + printf("{"); + printf("MultiIndex, "); + printf("size %d,", index_t{sizeof...(Xs)}); + static_for<0, sizeof...(Xs), 1>{}( + [&](auto i) { printf("%d ", static_cast(x.At(i))); }); + printf("}"); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/synchronization.hpp b/composable_kernel/include/utility/synchronization.hpp new file mode 100644 index 0000000000..da74f2074d --- /dev/null +++ b/composable_kernel/include/utility/synchronization.hpp @@ -0,0 +1,21 @@ +#ifndef CK_SYNCHRONIZATION_AMD_HPP +#define CK_SYNCHRONIZATION_AMD_HPP + +#include "config.hpp" + +namespace ck { + +__device__ void 
block_sync_lds() +{ +#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __syncthreads(); +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp new file mode 100644 index 0000000000..ee96a8b435 --- /dev/null +++ b/composable_kernel/include/utility/tuple.hpp @@ -0,0 +1,166 @@ +#ifndef CK_TUPLE_HPP +#define CK_TUPLE_HPP + +#include "integral_constant.hpp" +#include "sequence.hpp" +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { + +namespace detail { + +template +struct TupleElementKey +{ + __host__ __device__ constexpr TupleElementKey() = default; +}; + +template +struct TupleElement +{ + __host__ __device__ constexpr TupleElement() = default; + + template >, TupleElement>::value, + bool>::type = false> + __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) + { + } + + Data mData; +}; + +template +__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement& x) +{ + return static_cast(x.mData); +} + +template +__host__ __device__ constexpr Data& get_tuple_element(TupleElement& x) +{ + return x.mData; +} + +// TODO: not sure the use of reference is correct +template +__host__ __device__ constexpr Data&& get_tuple_element(TupleElement&& x) +{ + return static_cast(x.mData); +} + +template +struct TupleImpl; + +template +struct TupleImpl, Xs...> : TupleElement, Xs>... +{ + __host__ __device__ constexpr TupleImpl() = default; + + template >, TupleImpl>::value, + bool>::type = false> + __host__ __device__ constexpr TupleImpl(Y&& y) + : TupleElement, Xs>(std::forward(y))... + { + } + + template = 2, bool>::type = false> + __host__ __device__ constexpr TupleImpl(Ys&&... ys) + : TupleElement, Xs>(std::forward(ys))... + { + static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), + "wrong! 
inconsistent size"); + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + + template + __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey) const + { + return get_tuple_element>(*this); + } + + template + __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey) + { + return get_tuple_element>(*this); + } +}; + +} // namespace detail + +template +struct Tuple : detail::TupleImpl::type, Xs...> +{ + using base = + detail::TupleImpl::type, Xs...>; + + __host__ __device__ constexpr Tuple() = default; + + template >, Tuple>::value, + bool>::type = false> + __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) + { + } + + template = 2, bool>::type = + false> + __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward(ys)...) + { + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + + template + __host__ __device__ constexpr const auto& At(Number) const + { + static_assert(I < base::Size(), "wrong! out of range"); + return base::GetElementByKey(detail::TupleElementKey{}); + } + + template + __host__ __device__ constexpr auto& At(Number) + { + static_assert(I < base::Size(), "wrong! out of range"); + return base::GetElementByKey(detail::TupleElementKey{}); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_tuple(Xs&&... 
xs) +{ + return Tuple>...>(std::forward(xs)...); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/composable_kernel/include/utility/tuple_helper.hpp new file mode 100644 index 0000000000..9499a3596c --- /dev/null +++ b/composable_kernel/include/utility/tuple_helper.hpp @@ -0,0 +1,80 @@ +#ifndef CK_TUPLE_HELPER_HPP +#define CK_TUPLE_HELPER_HPP + +#include "functional4.hpp" +#include "tuple.hpp" + +namespace ck { + +template +struct is_known_at_compile_time> +{ + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return container_reduce( + Tuple{}, + [](auto x, bool r) { + return is_known_at_compile_time< + remove_cv_t>>::value & + r; + }, + true); + } + + static constexpr bool value = IsKnownAtCompileTime(); +}; + +template +__host__ __device__ constexpr auto generate_tuple(F&& f, Number) +{ + return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +namespace detail { + +template +__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence) +{ + return make_tuple(f(x.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); +} + +} // namespace detail + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x) +{ + return detail::transform_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y) +{ + return detail::transform_tuples_impl( + f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ 
__device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) +{ + return detail::transform_tuples_impl( + f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp new file mode 100644 index 0000000000..b7902ad496 --- /dev/null +++ b/composable_kernel/include/utility/type.hpp @@ -0,0 +1,56 @@ +#ifndef CK_TYPE_HPP +#define CK_TYPE_HPP + +#include "integral_constant.hpp" +#include "enable_if.hpp" + +namespace ck { + +template +struct is_same : public integral_constant +{ +}; + +template +struct is_same : public integral_constant +{ +}; + +template +using remove_reference_t = typename std::remove_reference::type; + +template +using remove_cv_t = typename std::remove_cv::type; + +template +inline constexpr bool is_pointer_v = std::is_pointer::value; + +template +struct is_known_at_compile_time; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + +template ::type = false> +__host__ __device__ constexpr Y as_type(X x) +{ + union AsType + { + X x; + Y y; + }; + + return AsType{x}.y; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp new file mode 100644 index 0000000000..9f34e044b7 --- /dev/null +++ b/composable_kernel/include/utility/utility.hpp @@ -0,0 +1,14 @@ +#ifndef CK_UTILITY_HPP +#define CK_UTILITY_HPP + +#include "config.hpp" + +namespace ck { + +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } + +__device__ index_t get_block_1d_id() { return blockIdx.x; } + +} // namespace ck + +#endif diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp 
b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..09a7fffa3e --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,370 @@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v1r2.hpp" +#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; +constexpr index_t M1PerThread = CK_PARAM_M1PerThread; +constexpr index_t N1PerThread = CK_PARAM_N1PerThread; +constexpr index_t KPerThread = CK_PARAM_KPerThread; +constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; +constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; +constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; +constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; + +using ABlockTransferThreadSliceLengths_K_M0_M1 = + Sequence; +using ABlockTransferThreadClusterLengths_K_M0_M1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = 
CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_M1 = + CK_PARAM_ABlockTransferDstScalarPerVector_M1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K_N0_N1 = + Sequence; +using BBlockTransferThreadClusterLengths_K_N0_N1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_N1 = + CK_PARAM_BBlockTransferDstScalarPerVector_N1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); + +extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k_m0_m1_grid_desc, + void* p_b_k_n0_n1_grid_desc, + void* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 
1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW)); + + const auto a_k_m_grid_desc = descs[I0]; + const auto b_k_n_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AKMGridDesc = decltype(a_k_m_grid_desc); + using BKNGridDesc = decltype(b_k_n_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r2; + + auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc; + *static_cast(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc; + *static_cast( + p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + }; +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k_m0_m1_grid_desc, + const void CONSTANT* p_b_k_n0_n1_grid_desc, + const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + 
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1)); + + constexpr auto a_k_m_grid_desc = descs[I0]; + constexpr auto b_k_n_grid_desc = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AKMGridDesc = decltype(a_k_m_grid_desc); + using BKNGridDesc = decltype(b_k_n_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r2; + + constexpr auto a_k_m0_m1_grid_desc_tmp = + 
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + constexpr auto b_k_n0_n1_grid_desc_tmp = + GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp); + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k_m0_m1_grid_desc = + *reinterpret_cast((const void*)p_a_k_m0_m1_grid_desc); + const auto b_k_n0_n1_grid_desc = + *reinterpret_cast((const void*)p_b_k_n0_n1_grid_desc); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..51d852617f --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,358 
@@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + 
+constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare( + int n, + int c, + int hi, + int wi, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi)); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x)); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + 
make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using AGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + 
*static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using BGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 
1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const 
void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp new file mode 100644 index 0000000000..30e4c518ce --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp @@ -0,0 +1,357 @@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr index_t MPerBlock = CK_PARAM_MPerBlock; +constexpr index_t NPerBlock = CK_PARAM_NPerBlock; +constexpr index_t KPerBlock = CK_PARAM_KPerBlock; + +constexpr index_t MPerWave = CK_PARAM_MPerWave; +constexpr index_t NPerWave = CK_PARAM_NPerWave; +constexpr index_t MRepeat = CK_PARAM_MRepeat; +constexpr index_t NRepeat = CK_PARAM_NRepeat; +constexpr index_t K1 = CK_PARAM_K1; + +using ABlockTransferThreadSliceLengths_K0_M_K1 = + Sequence; +using 
ABlockTransferThreadClusterLengths_K0_M_K1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = + Sequence; +using ABlockTransferSrcAccessOrder = Sequence; + +constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; +constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; +constexpr index_t ABlockTransferDstScalarPerVector_K1 = + CK_PARAM_ABlockTransferDstScalarPerVector_K1; +constexpr bool AThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); + +using BBlockTransferThreadSliceLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterLengths_K0_N_K1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = + Sequence; +using BBlockTransferSrcAccessOrder = Sequence; + +constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; +constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; +constexpr index_t BBlockTransferDstScalarPerVector_K1 = + CK_PARAM_BBlockTransferDstScalarPerVector_K1; +constexpr bool BThreadTransferSrcResetCoordinateAfterRun = + static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); + +using CThreadTransferSrcDstAccessOrder = Sequence; +constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare( + int n, + int hi, + int wi, + int c, + int k, + int y, + int x, + int convStrideH, + int convStrideW, + int convDilationY, + int convDilationX, + int leftPadH, + int leftPadW, + int rightPadH, + int rightPadW, + void* p_a_k0_m_k1_grid_desc, + void* p_b_k0_n_k1_grid_desc, + void* p_c_m0_m1_m2_n_grid_desc, + void* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = 
Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; + const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c)); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c)); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k)); + + const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(convStrideH, convStrideW), + make_tuple(convDilationY, convDilationX), + make_tuple(leftPadH, leftPadW), + make_tuple(rightPadH, rightPadW), + Number{}); + + const auto a_k0_m_k1_grid_desc = descs[I0]; + const auto b_k0_n_k1_grid_desc = descs[I1]; + const auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + if(hipThreadIdx_x == 0) + { + *static_cast*>(p_a_k0_m_k1_grid_desc) = + a_k0_m_k1_grid_desc; + *static_cast*>(p_b_k0_n_k1_grid_desc) = + b_k0_n_k1_grid_desc; + *static_cast(p_c_m0_m1_m2_n_grid_desc) = + c_m0_m1_m2_n_grid_desc; + *static_cast( + p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_a_k0_m_k1_grid_desc, + const void CONSTANT* p_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + constexpr auto in_n_hi_wi_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); + constexpr auto wei_k_y_x_c_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256)); + constexpr auto out_n_ho_wo_k_desc = + 
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256)); + + constexpr auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + Number{}); + + constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0]; + constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1]; + constexpr auto c_m_n_grid_desc = descs[I2]; + + using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp); + using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp); + using CMNGridDesc = decltype(c_m_n_grid_desc); + + using BGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}))); + + using AGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); + + using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}))); + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>; + using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>; + + using GridwiseGemm = + 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + constexpr auto c_m0_m1_m2_n_grid_desc_tmp = + GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp); + + const auto a_k0_m_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m_k1_grid_desc); + const auto b_k0_n_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n_k1_grid_desc); + const auto c_m0_m1_m2_n_grid_desc = + *reinterpret_cast((const void*)p_c_m0_m1_m2_n_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); +}; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..c1208ac3cb --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -0,0 +1,405 @@ +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_contraction_dlops_v1r2.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +constexpr DataTypeEnum_t ABDataTypeEnum = static_cast(CK_PARAM_ABDataTypeEnum); +constexpr DataTypeEnum_t 
AccDataTypeEnum = static_cast(CK_PARAM_AccDataTypeEnum); +constexpr DataTypeEnum_t CDataTypeEnum = static_cast(CK_PARAM_CDataTypeEnum); + +using FloatAB = typename get_datatype_from_enum::type; +using FloatAcc = typename get_datatype_from_enum::type; +using FloatC = typename get_datatype_from_enum::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr auto GN0 = Number{}; +constexpr auto GK1 = Number{}; + +constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; +constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; +constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; + +constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; +constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; +constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; + +using BM10BN10ThreadClusterBM10Xs = Sequence; +using BM10BN10ThreadClusterBN10Xs = Sequence; + +using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>; +using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>; +using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; + +using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>; +using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>; +using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; + +using 
CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>; +constexpr index_t CThreadTransferSrcDstVectorDim = 5; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBlockLoop); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); + +extern "C" __global__ void +convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_, + int C_, + int Hi_, + int Wi_, + int K_, + int Y_, + int X_, + int ConvStrideH_, + int ConvStrideW_, + int ConvDilationH_, + int ConvDilationW_, + int InLeftPadH_, + int InLeftPadW_, + int InRightPadH_, + int InRightPadW_, + void* p_desc_tuple) +{ + index_t N = static_cast(N_); + index_t C = static_cast(C_); + index_t Hi = static_cast(Hi_); + index_t Wi = static_cast(Wi_); + index_t K = static_cast(K_); + index_t Y = static_cast(Y_); + index_t X = static_cast(X_); + index_t ConvStrideH = static_cast(ConvStrideH_); + index_t ConvStrideW = static_cast(ConvStrideW_); + index_t ConvDilationH = static_cast(ConvDilationH_); + index_t ConvDilationW = static_cast(ConvDilationW_); + index_t InLeftPadH = static_cast(InLeftPadH_); + index_t InLeftPadW = static_cast(InLeftPadW_); + index_t InRightPadH = static_cast(InRightPadH_); + index_t InRightPadW = static_cast(InRightPadW_); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t Ho = + (Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1; + const index_t Wo = + (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi)); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X)); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); + + const auto descs 
= transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(ConvStrideH, ConvStrideW), + make_tuple(ConvDilationH, ConvDilationW), + make_tuple(InLeftPadH, InLeftPadW), + make_tuple(InRightPadH, InRightPadW), + GN0, + GK1); + + const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); + using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); + using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); + + using AGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using BGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: 
GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using CGridStepHacks = decltype(make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + + using BGridMoveSliceWindowStepHacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; + + using GridwiseContraction = + GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + 
ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; + + if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0) + { + auto desc_tuple = + make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + a_grid_desc_gk0_gm0_gm1_gk1), + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + b_grid_desc_gk0_gn0_gn1_gk1), + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1), + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1)); + + *static_cast(p_desc_tuple) = desc_tuple; + } +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* p_desc_tuple) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_c_hi_wi_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + 
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + GN0, + GK1); + + constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); + using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); + using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); + + using AGridStepHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using BGridStepHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using CGridStepHacks = decltype(make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 + + using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + + using BGridMoveSliceWindowStepHacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; + + using GridwiseContraction = + GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + 
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + a_grid_desc_gk0_gm0_gm1_gk1)); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + b_grid_desc_gk0_gn0_gn1_gk1)); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1)); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1)); + + using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{}, + BGridDesc_GK0_GN0_GN10_GN11_GK1{}, + CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, + CGridBlockCluster_BlockId_To_GM10_GN10{})); + + const auto desc_tuple = *reinterpret_cast( +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + // TODO: how to cast? 
+ (const void*)p_desc_tuple +#pragma clang diagnostic pop + ); + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2]; + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3]; + + constexpr index_t shared_block_size = + GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant{}, + integral_constant{}); +}; diff --git a/external/rocm/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp new file mode 100644 index 0000000000..52d00346cf --- /dev/null +++ b/external/rocm/include/bfloat16_dev.hpp @@ -0,0 +1,125 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2019 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef BFLOAT16_DEVICE_HPP +#define BFLOAT16_DEVICE_HPP + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __HIP_PLATFORM_HCC__ +#define EXECUTION_SPECIFIER __device__ +#else +#define EXECUTION_SPECIFIER +#endif // __HIP_PLATFORM_HCC__ + +typedef union +{ + uint u32; + ushort2 ushortx2; + +// Composable kernels are written in the HIP language, which doesn't support +// ushort2.hi or ushort2.lo, so HIP builds also get a plain two-element array view. +#ifdef __HIP_PLATFORM_HCC__ + ushort ushortvec[2]; +#endif // __HIP_PLATFORM_HCC__ + float f32; +} cvt_bf16_fp32_t; + +EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val) +{ + cvt_bf16_fp32_t target_val; + +#ifdef __HIP_PLATFORM_HCC__ + target_val.ushortx2 = make_ushort2(0, src_val); // src_val is the second 16-bit component, the first is zero +#else + target_val.ushortx2 = (ushort2)(0, src_val); +#endif // __HIP_PLATFORM_HCC__ + + return target_val.f32; +} + +EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val) +{ + cvt_bf16_fp32_t target_val; + target_val.f32 = src_val; + // BF16 round and NaN preservation code matches + // https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h + if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. 
If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. + if((target_val.u32 & 0xffff) != 0) + { + target_val.u32 |= 0x10000; // Preserve signaling NaN + } + } + else + { +#ifdef MIOPEN_USE_RNE_BFLOAT16 // when this macro is not defined, no rounding is applied in this branch +// When the exponent bits are not all 1s, then the value is zero, normal, +// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus +// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). +// This causes the bfloat16's mantissa to be incremented by 1 if the 16 +// least significant bits of the float mantissa are greater than 0x8000, +// or if they are equal to 0x8000 and the least significant bit of the +// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when +// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already +// has the value 0x7f, then incrementing it causes it to become 0x00 and +// the exponent is incremented by one, which is the next higher FP value +// to the unrounded bfloat16 value. When the bfloat16 value is subnormal +// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up +// to a normal value with an exponent of 0x01 and a mantissa of 0x00. +// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, +// incrementing it causes it to become an exponent of 0xFF and a mantissa +// of 0x00, which is Inf, the next higher value to the unrounded value. 
+#ifdef __HIP_PLATFORM_HCC__ + target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1)); +#else + target_val.u32 += + (0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even +#endif // __HIP_PLATFORM_HCC__ +#endif // MIOPEN_USE_RNE_BFLOAT16 + } + +#ifdef __HIP_PLATFORM_HCC__ + return target_val.ushortvec[1]; +#else + return target_val.ushortx2.hi; +#endif // __HIP_PLATFORM_HCC__ +} + +#ifdef __cplusplus +} +#endif + +#endif // BFLOAT16_DEVICE_HPP diff --git a/host/CMakeLists.txt b/host/CMakeLists.txt new file mode 100644 index 0000000000..30cc14d8ca --- /dev/null +++ b/host/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(host_tensor) +add_subdirectory(driver_offline) diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt new file mode 100644 index 0000000000..fec11e99af --- /dev/null +++ b/host/driver_offline/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/host/solver/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver + ${PROJECT_SOURCE_DIR}/external/rocm/include +) + +set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) +set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) + +add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) + +target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) diff --git 
a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp 
b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..7bd82bf6d5 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,330 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ // Uploads the three tensors, times driver_gemm_xdlops_v2r3 (weights as A, output as B, input as C), then copies the in_n_hi_wi_c buffer back to the host. 
+ using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 1 // active tuning configuration: the #elif branches below are never compiled while this is 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; 
+ + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t 
GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = 
Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + I0, + I0, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = + 
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; + + for(index_t i = 0; i < 5; ++i) // 5 timed rounds; each round prints its own average time and perf + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<1, 3, 7, 0, 2, 4, 5, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + 
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), + 
decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; // 2*N*K*Ho*Wo*C*Y*X flops / 1e9 / time in ms == TFlop/s + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..0ebf8571f4 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,306 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include 
"transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" +#include 
"driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using 
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + 
constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + 
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + I0, + I0, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + + constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves 
+ Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate 
after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy +#if 0 + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, +#else + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, +#endif + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + true // CAccessOrderMRepeatNRepeat + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_m1_m2_n_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff 
--git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e6554cf0fe --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,201 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_dlops_v1r2.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // cdata = 64, BlockSize = 256, 
128x128x8 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; + + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + const auto wei_gemmk_gemmm_grid_desc = descs[I0]; + const auto in_gemmk_gemmn_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_dlops_v1r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk_gemmm_grid_desc), + decltype(in_gemmk_gemmn_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_K_M0_M1, + GemmABlockTransferThreadClusterLengths_K_M0_M1, + Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder + Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder + 0, // ABlockTransferSrcVectorDim + GemmABlockTransferSrcScalarPerVector_K, + GemmABlockTransferDstScalarPerVector_M1, + false, // don't move back src coordinate after 
threadwise copy + GemmBBlockTransferThreadSliceLengths_K_N0_N1, + GemmBBlockTransferThreadClusterLengths_K_N0_N1, + Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + GemmBBlockTransferSrcScalarPerVector_N1, + GemmBBlockTransferDstScalarPerVector_N1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_N11, + decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), + decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk_gemmm_grid_desc, + in_gemmk_gemmn_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk_gemmm0_gemmn1_grid_step_hacks, + in_gemmk_gemmn0_gemmn1_grid_step_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, + wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks, + in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp 
b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..4a9d01081c --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,280 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = 
make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#endif + + const auto descs = +#if 1 + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad +#else + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 +#endif + ( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + for(index_t i = 0; i < 5; ++i) + { +#if 0 + float ave_time = launch_kernel_gemm_xdlops_v1 +#else + float ave_time = launch_kernel_gemm_xdlops_v2 +#endif + , + Sequence<1, 0, 2>, + 2, + 
GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy, which will be fused + // with MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(descs[I4]), + decltype(descs[I5]), + decltype(descs[I6]), + decltype(descs[I7]), + decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + descs[I0], + descs[I1], + descs[I2], + descs[I3], + descs[I4], + descs[I5], + descs[I6], + descs[I7], + descs[I8], + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..40685e81cf --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,273 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_dlops_v1r3.hpp" + 
+template +void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 1; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; + + using 
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 8, 2] for fp16 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 2; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector_N11 = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 8, 4] for i8 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 4; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>; + using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 + Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 + Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0 + Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10 + Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10 + Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11 + Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0 + Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 + Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 + + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; + +
for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_dlops_v1r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM110Xs, + GemmM11N11ThreadClusterN110Xs, + GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, + GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, + Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1, + GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1, + Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_N11, + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks), + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>( + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + 
in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..695ffeeb36 --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,197 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& 
in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; 
+ constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto 
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + 
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..141a326574 --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,229 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r2.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * 
out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 
1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 
0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1>, + 2, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + 
static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..692751bfb3 --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,302 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + 
const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using 
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = 
descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + 
GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; 
+ + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..7067291c8a --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,354 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + 
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 
4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, 
K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto out_m0_m1_m2_n_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + 
GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + in_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + out_m0_m1_m2_n_grid_step_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to 
host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b5e5f91d59 --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,190 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp" + +template +void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t /* nrepeat */) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = out_n_k_ho_wo_lengths[I0]; + const auto K = out_n_k_ho_wo_lengths[I1]; + const auto C = wei_k_c_y_x_lengths[I1]; + + const auto Hi = in_n_c_hi_wi_lengths[I2]; + const auto Wi = in_n_c_hi_wi_lengths[I3]; + + const auto Ho = out_n_k_ho_wo_lengths[I2]; + const auto Wo = out_n_k_ho_wo_lengths[I3]; + + const auto Y = wei_k_c_y_x_lengths[I2]; + const auto X = wei_k_c_y_x_lengths[I3]; + + const auto C0 = C / Number{}; + const auto C1 = Number{}; + + const auto K0 = K / Number{}; + const auto K1 = Number{}; + + Tensor in_n_c0_hi_wi_c1( + HostTensorDescriptor(std::initializer_list{N, C0, Hi, Wi, 
C1})); + Tensor wei_k_c0_y_x_c1( + HostTensorDescriptor(std::initializer_list{K, C0, Y, X, C1})); + Tensor out_n_k0_ho_wo_k1( + HostTensorDescriptor(std::initializer_list{N, K0, Ho, Wo, K1})); + + auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { + in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = + in_n_c_hi_wi(n, c, hi, wi); + }; + + auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) { + wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) = + wei_k_c_y_x(k, c, y, x); + }; + + make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); + make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); + + DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * + in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); + DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); + DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * + out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); + + in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); + wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + + const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi)); + const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X)); + const auto out_n_k0_ho_wo_k1_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)); + +#if 1 + // cdata = 64, BlockSize = 64, 16x8x32x4 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 16; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + constexpr index_t EPerBlock = 1; + + constexpr index_t KPerThread = KPerBlock; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = EPerBlock; + + using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>; + using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>; + + constexpr index_t 
ABlockTransferSrcScalarPerVector_E = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K = 1; + + constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; + + constexpr index_t CThreadTransferDstScalarPerVector_W = 16; + + static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); +#else + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 16; + constexpr index_t HoPerBlock = 8; + constexpr index_t WoPerBlock = 32; + constexpr index_t EPerBlock = 1; + + constexpr index_t KPerThread = 16; + constexpr index_t HoPerThread = 2; + constexpr index_t WoPerThread = 2; + constexpr index_t EPerThread = EPerBlock; + + using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; + using ABlockTransferThreadClusterLengths_E_K = Sequence; + + constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K = 1; + + constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; + + constexpr index_t CThreadTransferDstScalarPerVector_W = K1; + + static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); +#endif + + constexpr auto conv_driver = +#if 0 + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad +#else + DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad +#endif + ::type, + TAcc, + TOut, + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + BThreadTransferSrcScalarPerVector_W, + CThreadTransferDstScalarPerVector_W>{}; + + conv_driver.Run(wei_k_c0_y_x_desc, + in_n_c0_hi_wi_desc, + out_n_k0_ho_wo_k1_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + 
static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer())); + + out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); + + auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { + out_n_k_ho_wo(n, k, ho, wo) = + out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize); + }; + + make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)(); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e1b7c5486c --- /dev/null +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,241 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include "driver_contraction_dlops_v1r2.hpp" + +template +void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + 
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GN0 = 4; + constexpr index_t GK1 = 1; + + constexpr index_t GM1PerBlockGM11 = 128; + constexpr index_t GN1PerBlockGN11 = 32; + constexpr index_t GK0PerBlock = 8; + + constexpr index_t BM1PerThreadBM11 = 4; + constexpr index_t BN1PerThreadBN11 = 4; + constexpr index_t BK0PerThread = 1; + + using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; + using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; + + using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + + using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; + + using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; + using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + + using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + + constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; +#elif 1 + // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GN0 = 4; + constexpr index_t GK1 = 2; + + constexpr 
index_t GM1PerBlockGM11 = 128; + constexpr index_t GN1PerBlockGN11 = 32; + constexpr index_t GK0PerBlock = 8; + + constexpr index_t BM1PerThreadBM11 = 4; + constexpr index_t BN1PerThreadBN11 = 4; + constexpr index_t BK0PerThread = 1; + + using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>; + using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>; + + using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; + using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + + using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; + + using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; + using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + + using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; + + constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; +#endif + + const auto descs = + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x, + in_desc_n_c_hi_wi, + out_desc_n_k_ho_wo, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + Number{}); + + const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 
0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 + + constexpr auto in_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 + + constexpr auto out_grid_step_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 + + constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; + + constexpr auto in_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_contraction_dlops_v1r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(wei_grid_desc_gk0_gm0_gm1_gk1), + decltype(in_grid_desc_gk0_gn0_gn1_gk1), + decltype(out_grid_desc_gm0_gm1_gn0_gn1), + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder + Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + CThreadTransferDstScalarPerVector_BN1, + decltype(wei_grid_step_hacks), + decltype(in_grid_step_hacks), + decltype(out_grid_step_hacks), + 
decltype(wei_grid_move_slice_window_step_hacks), + decltype(in_grid_move_slice_window_step_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_grid_desc_gk0_gm0_gm1_gk1, + in_grid_desc_gk0_gn0_gn1_gk1, + out_grid_desc_gm0_gm1_gn0_gn1, + wei_grid_step_hacks, + in_grid_step_hacks, + out_grid_step_hacks, + wei_grid_move_slice_window_step_hacks, + in_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp new file mode 100644 index 0000000000..d207728a2e --- /dev/null +++ b/host/driver_offline/include/driver_contraction_dlops_v1r2.hpp @@ -0,0 +1,286 @@ +#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP +#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_contraction_dlops_v1r2.hpp" + +template +__host__ float +driver_contraction_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + 
constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseContraction = + GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + CGlobalMemoryDataOperation, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks>; + + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + + if(!GridwiseContraction::CheckValidity( + a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1)) + { + throw std::runtime_error("wrong! 
" + "GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" + "GM0_GM1_GN0_GN1 has invalid setting"); + } + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = + GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1); + + // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1); + + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); + + // c_grid_block_cluster_blockid_to_gm10_gn10 + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1); + + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(c_grid_block_cluster_blockid_to_gm10_gn10); + + const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1); + + const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0); + + const bool has_double_tail_k_block_loop = + GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0); + + { + std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{" + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl; + + std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{" + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", " + << 
b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl; + + std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl; + } + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = 
kernel_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + else + { + const auto kernel = kernel_contraction_dlops_v1r2< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); + } + + return ave_time; +} +#endif diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..efd4ce6a19 --- /dev/null +++ b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,349 @@ +#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad +{ + template + __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, + const ck::TensorDescriptor& 
in_n_c_hi_wi_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c_y_x_global_desc.GetLength(I0); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_e_k_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = 
transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Ho), + make_pass_through_transform(Wo)), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pass_through_transform(Ho), + make_pass_through_transform(Wo)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto E = C * Y * X; + + if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 && + (E % EPerBlock) == 0)) + { + throw std::runtime_error("wrong! 
GEMM size no divisible"); + } + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_e_k_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; + + constexpr auto b_e_n_ho_wo_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + +#if 1 + // GEMM + using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + decltype(wei_e_k_global_desc), + decltype(in_e_n_ho_wo_global_desc), + decltype(out_k_n_ho_wo_global_desc), + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + ABlockTransferSrcScalarPerVector_E, + 
ABlockTransferDstScalarPerVector_K, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 3, 1>, + 3, + BThreadTransferSrcScalarPerVector_W, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 2, 3, 1>, + 0, + CThreadTransferDstScalarPerVector_W, + decltype(a_e_k_global_step_hacks), + decltype(b_e_n_ho_wo_global_step_hacks), + decltype(c_k_n_ho_wo_global_tensor_step_hacks), + decltype(a_e_k_global_move_slice_window_step_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; + + const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; + + const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + std::cout << "has_main_k_block_loop: " << has_main_k_block_loop + << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop + << std::endl; + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } 
+ else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = + static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; +#endif diff --git a/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp new file mode 100644 index 0000000000..70f73cbf4a --- /dev/null +++ b/host/driver_offline/include/driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp @@ -0,0 +1,364 @@ +#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP +#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +template +struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad +{ + template 
+ __host__ void Run(const ck::TensorDescriptor& wei_k_c_y_x_global_desc, + const ck::TensorDescriptor& in_n_c_hi_wi_global_desc, + const ck::TensorDescriptor& out_n_k0_ho_wo_k1_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2); + const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3); + + const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4); + + const auto K = wei_k_c_y_x_global_desc.GetLength(I0); + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; + const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock; + + const auto OutRightPadH = Hop - Ho; + const auto OutRightPadW = Wop - Wo; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH; + const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW; + + std::cerr << 
"OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW + << std::endl; + std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW + << std::endl; + + // weight tensor + const auto wei_e_k_global_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Hop), + make_pass_through_transform(Wop)), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor( + 
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)), + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(N), + make_pad_transform(Ho, 0, OutRightPadH), + make_pad_transform(Wo, 0, OutRightPadW)), + make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto E = C * Y * X; + + std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl; + + if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 && + (E % EPerBlock) == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_e_k_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{}; + + constexpr auto b_e_n_ho_wo_global_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 2, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperationEnum_t::Set, + decltype(wei_e_k_global_desc), + decltype(in_e_n_ho_wo_global_desc), + decltype(out_k_n_hop_wop_global_desc), + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 3, 1>, + 3, + BThreadTransferSrcScalarPerVector_W, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 2, 3, 1>, + 0, + CThreadTransferDstScalarPerVector_W, + decltype(a_e_k_global_step_hacks), + decltype(b_e_n_ho_wo_global_step_hacks), + decltype(c_k_n_ho_wo_global_tensor_step_hacks), + decltype(a_e_k_global_move_slice_window_step_hack), + decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>; + + const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N; + + const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + std::cout << "has_main_k_block_loop: " << has_main_k_block_loop + << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop + << std::endl; + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + wei_e_k_global_desc, + p_wei_global, + in_e_n_ho_wo_global_desc, + p_in_global, + out_k_n_hop_wop_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = + static_cast(calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k0_ho_wo_k1_global_desc)) / + (std::size_t(1000) * 1000 * 1000) 
/ ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } +}; +#endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp new file mode 100644 index 0000000000..bf5f7f1c0f --- /dev/null +++ b/host/driver_offline/include/driver_gemm_dlops_v1r2.hpp @@ -0,0 +1,413 @@ +#ifndef DRIVER_GEMM_DLOPS_V1R2 +#define DRIVER_GEMM_DLOPS_V1R2 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v1r2.hpp" + +template +__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2; + + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error("wrong! 
GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting"); + } + + const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); + + { + std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " + << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " + << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem 
a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc)); + DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc); + b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + 
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp new file mode 100644 index 0000000000..4470918820 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_dlops_v1r3.hpp @@ -0,0 +1,418 @@ +#ifndef DRIVER_GEMM_DLOPS_V1R3 +#define DRIVER_GEMM_DLOPS_V1R3 + +#include 
"common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_dlops_v1r3.hpp" + +template +__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseGemmDlops_km_kn_mn_v1r3; + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error("wrong! 
GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting"); + } + + const auto a_k0_m0_m1_k1_grid_desc = + GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); + + using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); + using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + { + std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << 
c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + 
c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); + DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); + b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + 
cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + else + { + const auto kernel = + kernel_gemm_dlops_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); + } + + return ave_time; +#endif +} +#endif diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp 
b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..edfce52a19 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -0,0 +1,191 @@ +#ifndef DRIVER_GEMM_XDLOPS_V2R3 +#define DRIVER_GEMM_XDLOPS_V2R3 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +template +__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + using GridwiseGemm = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + { + std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + + const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); + + const auto kernel = kernel_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); + +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); + DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); + DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); + + a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); + b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); + c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); + + float ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); 
+#endif + return ave_time; +} +#endif diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp new file mode 100644 index 0000000000..67cea94813 --- /dev/null +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -0,0 +1,321 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_data.hpp" +#include "device_tensor.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_MODE 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 + +enum ConvBackwardDataAlgo +{ + V4R1XDLNHWC, + V4R1R2XDLNHWC, +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = 
std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); + + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + const index_t conv_stride_h = 2; + const index_t conv_stride_w = 2; + const index_t conv_dilation_h = 1; + const index_t conv_dilation_w = 1; + const index_t in_left_pad_h = 1; + const index_t in_left_pad_w = 1; + const index_t in_right_pad_h = 1; + const index_t in_right_pad_w = 1; + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi 
+ in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#endif + +#if 0 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + if(layout == ConvTensorLayout::NCHW) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + } + else if(layout == ConvTensorLayout::NHWC) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + } + else + { + throw std::runtime_error("wrong! 
not implemented"); + } + + Tensor in_host(in_lengths_host); + Tensor in_device(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + auto f_make_for_device_nhwc = [&]() { +#if USE_MODE + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_BWD_V4R1_XDL_NHWC + if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } +#endif + +#if USE_CONV_BWD_V4R1R2_XDL_NHWC + if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_data(in_host, + wei, + out, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(in_host, in_device); + + if(do_log) + { + LogRangeAsType(std::cout << "out : ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in_host : ", in_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in_device: ", in_device.mData, ",") << std::endl; + } + } +} diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp new file mode 100644 index 0000000000..32c33003c5 --- /dev/null +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -0,0 +1,474 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv.hpp" +#include "device_tensor.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_MODE 1 +#define 
USE_CONV_FWD_V4R4_NCHW 1 +#define USE_CONV_FWD_V4R4R2_NHWC 1 +#define USE_CONV_FWD_V6R1_NCHW 0 +#define USE_CONV_FWD_V5R1_NCHW 0 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 + +enum ConvForwardAlgo +{ + V4R4NCHW, // 0 + V4R4R2NHWC, // 1 + V6R1NCHW, // 2 + V5R1NCHW, // 3 + V4R4R2XDLNCHW, // 4 + V4R4R4XDLNHWC // 5 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); + + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; 
+ const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvForwardAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + const index_t conv_stride_h = 2; + const index_t conv_stride_w = 2; + const index_t conv_dilation_h = 1; + const index_t conv_dilation_w = 1; + const index_t in_left_pad_h = 1; + const index_t in_left_pad_w = 1; + const index_t in_right_pad_h = 1; + const index_t in_right_pad_w = 1; + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#endif + +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + if(layout == ConvTensorLayout::NCHW) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = 
static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + } + else if(layout == ConvTensorLayout::NHWC) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + } + else + { + std::runtime_error("wrong! not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out_host(out_lengths_host); + Tensor out_device(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, 
num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + + auto f_make_for_device_nchw = [&]() { +#if USE_MODE + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + 
}; + + auto f_make_for_device_nhwc = [&]() { +#if USE_MODE + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_FWD_V4R4_NCHW + if(algo == ConvForwardAlgo::V4R4NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R2_NHWC + if(algo == ConvForwardAlgo::V4R4R2NHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V6R1_NCHW + if(algo == ConvForwardAlgo::V6R1NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V5R1_NCHW + if(algo == ConvForwardAlgo::V5R1NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R2_XDL_NCHW + if(algo == ConvForwardAlgo::V4R4R2XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R4_XDL_NHWC + if(algo == ConvForwardAlgo::V4R4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(out_host, out_device); + + if(do_log) + { + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/CMakeLists.txt b/host/host_tensor/CMakeLists.txt new file mode 100644 index 0000000000..3dcecf64e1 --- /dev/null +++ b/host/host_tensor/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + include +) + +set(HOST_TENSOR_SOURCE + src/host_tensor.cpp; + src/device.cpp; +) + +## the library target +add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) + +target_include_directories(host_tensor SYSTEM PUBLIC $) + +target_link_libraries(host_tensor PRIVATE hip::device) +target_link_libraries(host_tensor INTERFACE hip::host) + +target_compile_features(host_tensor PUBLIC) +set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) + +install(TARGETS host_tensor LIBRARY DESTINATION lib) diff --git a/host/host_tensor/include/conv_common.hpp b/host/host_tensor/include/conv_common.hpp new file mode 100644 index 0000000000..4bf2c23494 --- /dev/null +++ b/host/host_tensor/include/conv_common.hpp @@ -0,0 +1,86 @@ +#ifndef CONV_COMMON_HPP +#define CONV_COMMON_HPP + +#include "tensor_descriptor.hpp" + +enum ConvTensorLayout +{ + NCHW, + NHWC, 
+ CHWN, + NCHWc, + NHWCc +}; + +template +constexpr auto get_convolution_output_default_4d_tensor_descriptor( + const ck::TensorDescriptor& in_desc, + const ck::TensorDescriptor& wei_desc, + const ConvStrides& conv_strides, + const ConvDilations conv_dilations, + const LeftPads& left_pads, + const RightPads& right_pads) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + assert(in_desc.GetNumOfDimension() == 4); + assert(wei_desc.GetNumOfDimension() == 4); + assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1)); + + const auto N = in_desc.GetLength(I0); + const auto Hi = in_desc.GetLength(I2); + const auto Wi = in_desc.GetLength(I3); + + const auto K = wei_desc.GetLength(I0); + const auto Y = wei_desc.GetLength(I2); + const auto X = wei_desc.GetLength(I3); + + const auto LeftPadH = left_pads[I0]; + const auto LeftPadW = left_pads[I1]; + + const auto RightPadH = right_pads[I0]; + const auto RightPadW = right_pads[I1]; + + const auto YEff = (Y - I1) * conv_dilations[I0] + I1; + const auto XEff = (X - I1) * conv_dilations[I1] + I1; + + const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; + const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; + + return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); +} + +template +constexpr std::size_t +calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const index_t N = out_desc.GetLength(I0); + const index_t K = out_desc.GetLength(I1); + const index_t Ho = out_desc.GetLength(I2); + const index_t Wo = out_desc.GetLength(I3); + + const index_t C = wei_desc.GetLength(I1); + const index_t Y = wei_desc.GetLength(I2); + const index_t X = 
wei_desc.GetLength(I3); + + return std::size_t(2) * N * K * Ho * Wo * C * Y * X; +} + +#endif diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp new file mode 100644 index 0000000000..e2cba94100 --- /dev/null +++ b/host/host_tensor/include/device.hpp @@ -0,0 +1,80 @@ +#ifndef DEVICE_HPP +#define DEVICE_HPP + +#include +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" + +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + void ToDevice(const void* p); + void FromDevice(void* p); + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +struct KernelTimerImpl; + +struct KernelTimer +{ + KernelTimer(); + ~KernelTimer(); + void Start(); + void End(); + float GetElapsedTime() const; + + std::unique_ptr impl; +}; + +using device_stream_t = hipStream_t; + +template +void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) +{ + hipStream_t stream_id = nullptr; + + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); +} + +template +float launch_and_time_kernel( + F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... 
args) +{ + KernelTimer timer; + + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up\n"); + + hipStream_t stream_id = nullptr; + + // warm up + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + + printf("Start running %d times...\n", nrepeat); + + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + } + + timer.End(); + + return timer.GetElapsedTime() / nrepeat; +} + +#endif diff --git a/host/host_tensor/include/device_tensor.hpp b/host/host_tensor/include/device_tensor.hpp new file mode 100644 index 0000000000..1a7a34a4cf --- /dev/null +++ b/host/host_tensor/include/device_tensor.hpp @@ -0,0 +1,9 @@ +#pragma once +#include "host_tensor.hpp" +#include "common_header.hpp" + +template +void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) +{ + ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os); +} diff --git a/host/host_tensor/include/host_conv.hpp b/host/host_tensor/include/host_conv.hpp new file mode 100644 index 0000000000..c1228f4832 --- /dev/null +++ b/host/host_tensor/include/host_conv.hpp @@ -0,0 +1,324 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_direct_convolution(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - 
in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(wei(k, c, y, x)); + } + } + } + } + out(n, k, ho, wo) = v; + }; + + auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(wei(k, y, x, c)); + } + } + } + } + out(n, ho, wo, k) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_nhwc, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} + +template +void host_winograd_3x3_convolution(const Tensor& in_nchw, + const Tensor& wei_kcyx, + Tensor& out_nkhw, + InLeftPads, + InRightPads) +{ + using namespace ck; + + constexpr std::size_t HoPerTile = 2; + constexpr std::size_t WoPerTile = 2; + + std::size_t N = in_nchw.mDesc.GetLengths()[0]; + std::size_t C = in_nchw.mDesc.GetLengths()[1]; + + std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; + std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; + std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; + + std::size_t Ho = out_nkhw.mDesc.GetLengths()[2]; + std::size_t Wo = out_nkhw.mDesc.GetLengths()[3]; + + index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); + index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); + + std::size_t HiPerTile = HoPerTile + Y - 1; + std::size_t WiPerTile = WoPerTile + X - 1; + + std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile; + std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile; + + Tensor in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile}); + Tensor in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile}); + Tensor wei_transform({K, C, HiPerTile, WiPerTile}); + Tensor out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile}); + Tensor out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile}); + + auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) { + for(int j = 0; j < HiPerTile; ++j) + { + int hi = HoPerTile * htile + j - h_pad_low; + for(int i = 0; i < WiPerTile; ++i) + { + int wi = WoPerTile * wtile + i - w_pad_low; + + if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && + wi < in_nchw.mDesc.GetLengths()[3]) + { + in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi); + } + else + { + in_hold(n, c, htile, wtile, j, i) = TIn(0); + } + } + } + }; + + auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) { + in_transform(n, c, htile, wtile, 0, 0) = + in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) - + in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, 
htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 1) = + in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) - + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 2) = + -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 0, 3) = + in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3); + + in_transform(n, c, htile, wtile, 1, 0) = + in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 1) = + in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 2) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 1, 3) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); + + in_transform(n, c, htile, wtile, 2, 0) = + -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 1) = + -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 2) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2); + in_transform(n, c, htile, wtile, 2, 3) = + -in_hold(n, c, htile, 
wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) + + in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3); + + in_transform(n, c, htile, wtile, 3, 0) = + in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 1) = + in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) - + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 2) = + -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) + + in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2); + in_transform(n, c, htile, wtile, 3, 3) = + in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) - + in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3); + }; + + auto f_wei_transform = [&](auto k, auto c) { + wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0)); + wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + + 0.5 * double(wei_kcyx(k, c, 0, 1)) + + 0.5 * double(wei_kcyx(k, c, 0, 2)); + wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - + 0.5 * double(wei_kcyx(k, c, 0, 1)) + + 0.5 * double(wei_kcyx(k, c, 0, 2)); + wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2)); + + wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) + + 0.5 * double(wei_kcyx(k, c, 1, 0)) + + 0.5 * double(wei_kcyx(k, c, 2, 0)); + wei_transform(k, c, 1, 1) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) + + 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 1, 2) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * 
double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) - + 0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) + + 0.5 * double(wei_kcyx(k, c, 1, 2)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + + wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) - + 0.5 * double(wei_kcyx(k, c, 1, 0)) + + 0.5 * double(wei_kcyx(k, c, 2, 0)); + wei_transform(k, c, 2, 1) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) - + 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 2, 2) = + 0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) + + 0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) + + 0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) + + 0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) + + 0.25 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) - + 0.5 * double(wei_kcyx(k, c, 1, 2)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + + wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0)); + wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) + + 0.5 * double(wei_kcyx(k, c, 2, 1)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) - + 0.5 * double(wei_kcyx(k, c, 2, 1)) + + 0.5 * double(wei_kcyx(k, c, 2, 2)); + wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2)); + }; + + auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) { + for(int j = 0; j < HiPerTile; ++j) + { + for(int 
i = 0; i < WiPerTile; ++i) + { + double v = 0; + for(int c = 0; c < C; ++c) + { + v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i); + } + + out_transform(n, k, htile, wtile, j, i) = v; + } + } + }; + + auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) { + out_hold(n, k, htile, wtile, 0, 0) = + out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) + + out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) + + out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) + + out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) + + out_transform(n, k, htile, wtile, 2, 2); + out_hold(n, k, htile, wtile, 0, 1) = + out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) - + out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) - + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) + + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - + out_transform(n, k, htile, wtile, 2, 3); + out_hold(n, k, htile, wtile, 1, 0) = + out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) + + out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) - + out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) - + out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) - + out_transform(n, k, htile, wtile, 3, 2); + out_hold(n, k, htile, wtile, 1, 1) = + out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) - + out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) + + out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) - + out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) + + out_transform(n, k, htile, wtile, 3, 3); + }; + 
+ auto f_out = [&](auto n, auto k, auto htile, auto wtile) { + for(int j = 0; j < HoPerTile; ++j) + { + std::size_t ho = HoPerTile * htile + j; + for(int i = 0; i < WoPerTile; ++i) + { + std::size_t wo = WoPerTile * wtile + i; + out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i); + } + } + }; + + std::size_t num_thread = std::thread::hardware_concurrency(); + + make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread); + make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread); + make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread); +} diff --git a/host/host_tensor/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp new file mode 100644 index 0000000000..ca23422e23 --- /dev/null +++ b/host/host_tensor/include/host_conv_bwd_data.hpp @@ -0,0 +1,135 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_direct_convolution_backward_data(Tensor& in, + const Tensor& wei, + const Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& /* in_right_pads */, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I2]; + std::size_t X = wei.mDesc.GetLengths()[I3]; + + std::size_t Ho = out.mDesc.GetLengths()[I2]; + std::size_t Wo = out.mDesc.GetLengths()[I3]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + 
if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, k, ho, wo) * wei(k, c, y, x); + } + } + } + } + } + } + } + + in(n, c, hi, wi) = v; + }; + + auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I1]; + std::size_t X = wei.mDesc.GetLengths()[I2]; + + std::size_t Ho = out.mDesc.GetLengths()[I1]; + std::size_t Wo = out.mDesc.GetLengths()[I2]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, ho, wo, k) * wei(k, y, x, c); + } + } + } + } + } + } + } + + in(n, hi, wi, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_nchw, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_nhwc, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} diff --git a/host/host_tensor/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp new file mode 100644 index 0000000000..06aed0a0c1 --- /dev/null +++ b/host/host_tensor/include/host_tensor.hpp @@ -0,0 +1,322 @@ +#ifndef HOST_TENSOR_HPP +#define HOST_TENSOR_HPP + +#include +#include +#include +#include +#include +#include +#include + +template +std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << v; + } + return os; +} + +template +std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << static_cast(v); + } + return os; +} + +typedef enum +{ + Half = 0, + Float = 1, +} DataType_t; + +template +struct DataType; + +template <> +struct DataType : std::integral_constant +{ +}; + +template +auto call_f_unpack_args_impl(F f, T args, std::index_sequence) +{ + return f(std::get(args)...); +} + +template +auto call_f_unpack_args(F f, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); +} + +template +auto construct_f_unpack_args_impl(T args, std::index_sequence) +{ + return F(std::get(args)...); +} + +template +auto construct_f_unpack_args(F, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return construct_f_unpack_args_impl(args, std::make_index_sequence{}); +} + +struct HostTensorDescriptor +{ + HostTensorDescriptor() = delete; + + template + HostTensorDescriptor(std::vector lens); + + template + HostTensorDescriptor(std::vector lens, std::vector strides); + + void CalculateStrides(); + + template + HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template + HostTensorDescriptor(const Range1& lens, const 
Range2& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + std::size_t GetNumOfDimension() const; + std::size_t GetElementSize() const; + std::size_t GetElementSpace() const; + + const std::vector& GetLengths() const; + const std::vector& GetStrides() const; + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + assert(sizeof...(Is) == this->GetNumOfDimension()); + std::initializer_list iss{static_cast(is)...}; + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + private: + std::vector mLens; + std::vector mStrides; +}; + +struct joinable_thread : std::thread +{ + template + joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) + { + } + + joinable_thread(joinable_thread&&) = default; + joinable_thread& operator=(joinable_thread&&) = default; + + ~joinable_thread() + { + if(this->joinable()) + this->join(); + } +}; + +template +struct ParallelTensorFunctor +{ + F mF; + static constexpr std::size_t NDIM = sizeof...(Xs); + std::array mLens; + std::array mStrides; + std::size_t mN1d; + + ParallelTensorFunctor(F f, Xs... 
xs) : mF(f), mLens({static_cast(xs)...}) + { + mStrides.back() = 1; + std::partial_sum(mLens.rbegin(), + mLens.rend() - 1, + mStrides.rbegin() + 1, + std::multiplies()); + mN1d = mStrides[0] * mLens[0]; + } + + std::array GetNdIndices(std::size_t i) const + { + std::array indices; + + for(int idim = 0; idim < NDIM; ++idim) + { + indices[idim] = i / mStrides[idim]; + i -= indices[idim] * mStrides[idim]; + } + + return indices; + } + + void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const + { + std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); + + auto f = [=] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + call_f_unpack_args(mF, GetNdIndices(iw)); + } + }; + threads[it] = joinable_thread(f); + } + } +}; + +template +auto make_ParallelTensorFunctor(F f, Xs... 
xs) +{ + return ParallelTensorFunctor(f, xs...); +} + +template +struct Tensor +{ + template + Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpace()) + { + } + + template + Tensor(std::vector lens) : mDesc(lens), mData(mDesc.GetElementSpace()) + { + } + + template + Tensor(std::vector lens, std::vector strides) + : mDesc(lens, strides), mData(mDesc.GetElementSpace()) + { + } + + Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} + + template + void GenerateTensorValue(G g, std::size_t num_thread = 1) + { + switch(mDesc.GetNumOfDimension()) + { + case 1: { + auto f = [&](auto i) { (*this)(i) = g(i); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread); + break; + } + case 2: { + auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread); + break; + } + case 3: { + auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); }; + make_ParallelTensorFunctor( + f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread); + break; + } + case 4: { + auto f = [&](auto i0, auto i1, auto i2, auto i3) { + (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3])(num_thread); + break; + } + default: throw std::runtime_error("unspported dimension"); + } + } + + template + T& operator()(Is... is) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + template + const T& operator()(Is... 
is) const + { + return mData[mDesc.GetOffsetFromMultiIndex(is...)]; + } + + typename std::vector::iterator begin() { return mData.begin(); } + + typename std::vector::iterator end() { return mData.end(); } + + typename std::vector::const_iterator begin() const { return mData.begin(); } + + typename std::vector::const_iterator end() const { return mData.end(); } + + HostTensorDescriptor mDesc; + std::vector mData; +}; + +template +HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens) +{ + this->CalculateStrides(); +} + +template +HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides) + : mLens(lens), mStrides(strides) +{ +} + +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); + +template +void check_error(const Tensor& ref, const Tensor& result) +{ + float error = 0; + float max_diff = -1; + float ref_value = 0, result_value = 0; + for(int i = 0; i < ref.mData.size(); ++i) + { + error += std::abs(double(ref.mData[i]) - double(result.mData[i])); + float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); + if(max_diff < diff) + { + max_diff = diff; + ref_value = ref.mData[i]; + result_value = result.mData[i]; + } + } + + std::cout << "error: " << error << std::endl; + std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; +} + +#endif diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp new file mode 100644 index 0000000000..7c09843d01 --- /dev/null +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -0,0 +1,60 @@ +#ifndef HOST_TENSOR_GENERATOR_HPP +#define HOST_TENSOR_GENERATOR_HPP + +#include +#include "config.hpp" + +struct GeneratorTensor_1 +{ + int value = 1; + + template + float operator()(Is...) + { + return value; + } +}; + +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + float operator()(Is...) 
+ { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + +template +struct GeneratorTensor_3 +{ + T min_value = 0; + T max_value = 1; + + template + float operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + return min_value + tmp * (max_value - min_value); + } +}; + +struct GeneratorTensor_Checkboard +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + return std::accumulate(dims.begin(), + dims.end(), + true, + [](bool init, ck::index_t x) -> int { return init != (x % 2); }) + ? 1 + : -1; + } +}; + +#endif diff --git a/host/host_tensor/src/device.cpp b/host/host_tensor/src/device.cpp new file mode 100644 index 0000000000..0d1b3d6883 --- /dev/null +++ b/host/host_tensor/src/device.cpp @@ -0,0 +1,67 @@ +#include "device.hpp" + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } + +void DeviceMem::ToDevice(const void* p) +{ + hipGetErrorString( + hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) +{ + hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } + +struct KernelTimerImpl +{ + KernelTimerImpl() + { + hipGetErrorString(hipEventCreate(&mStart)); + hipGetErrorString(hipEventCreate(&mEnd)); + } + + ~KernelTimerImpl() + { + hipGetErrorString(hipEventDestroy(mStart)); + hipGetErrorString(hipEventDestroy(mEnd)); + } + + void Start() + { + hipGetErrorString(hipDeviceSynchronize()); + hipGetErrorString(hipEventRecord(mStart, nullptr)); + } + + void End() + { + hipGetErrorString(hipEventRecord(mEnd, nullptr)); + hipGetErrorString(hipEventSynchronize(mEnd)); + } + + float GetElapsedTime() const + { + float time; + hipGetErrorString(hipEventElapsedTime(&time, mStart, 
mEnd)); + return time; + } + + hipEvent_t mStart, mEnd; +}; + +KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} + +KernelTimer::~KernelTimer() {} + +void KernelTimer::Start() { impl->Start(); } + +void KernelTimer::End() { impl->End(); } + +float KernelTimer::GetElapsedTime() const { return impl->GetElapsedTime(); } diff --git a/host/host_tensor/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp new file mode 100644 index 0000000000..e840baf7f5 --- /dev/null +++ b/host/host_tensor/src/host_tensor.cpp @@ -0,0 +1,48 @@ +#include +#include + +#include "host_tensor.hpp" + +void HostTensorDescriptor::CalculateStrides() +{ + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); +} + +std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + +std::size_t HostTensorDescriptor::GetElementSize() const +{ + assert(mLens.size() == mStrides.size()); + return std::accumulate( + mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies()); +} + +std::size_t HostTensorDescriptor::GetElementSpace() const +{ + auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; }); + return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1; +} + +const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + +const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) +{ + os << "dim " << desc.GetNumOfDimension() << ", "; + + os << "lengths {"; + LogRange(os, desc.GetLengths(), ", "); + os << "}, "; + + os << "strides {"; + LogRange(os, desc.GetStrides(), ", "); + os << "}" << std::endl; +} diff --git a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp 
b/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..2b645e3c3b --- /dev/null +++ b/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,689 @@ +#ifndef CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_IGEMM_FWD_V6R1_DLOPS_NCHW_KCYX_NKHW_HPP + +#include +#include + +namespace ck { +namespace driver { + +struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + auto GetCompileParameterString() const + { + auto param = std::stringstream(); + + // clang-format off + param << + " -DCK_PARAM_ABDataTypeEnum=" << + ABDataTypeEnum << + " -DCK_PARAM_AccDataTypeEnum=" << + AccDataTypeEnum << + " -DCK_PARAM_CDataTypeEnum=" << + CDataTypeEnum << + " -DCK_PARAM_BlockSize=" << + BlockSize << + " -DCK_PARAM_GN0=" << + GN0 << + " -DCK_PARAM_GK1=" << + GK1 << + " -DCK_PARAM_GM1PerBlockGM11=" + << GM1PerBlockGM11 << + " -DCK_PARAM_GN1PerBlockGN11=" << + GN1PerBlockGN11 << + " -DCK_PARAM_GK0PerBlock=" << + GK0PerBlock << + " -DCK_PARAM_BM1PerThreadBM11=" << + BM1PerThreadBM11 << + " -DCK_PARAM_BN1PerThreadBN11=" << + BN1PerThreadBN11 << + " -DCK_PARAM_BK0PerThread=" << + BK0PerThread << + " -DCK_PARAM_BM10BN10ThreadClusterBM10Xs=" << + BM10BN10ThreadClusterBM10Xs[0] << "," << + BM10BN10ThreadClusterBM10Xs[1] << + " -DCK_PARAM_BM10BN10ThreadClusterBN10Xs=" << + BM10BN10ThreadClusterBN10Xs[0] << "," << + BM10BN10ThreadClusterBN10Xs[1] << + " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" << + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4] << + " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" << + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << + 
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4] << + " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" << + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] << + " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" << + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0] << "," << + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1] << "," << + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2] << "," << + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3] << "," << + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4] << + " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" << + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4] << + " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" << + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << + 
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4] << + " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" << + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] << + " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" << + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0] << "," << + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1] << "," << + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2] << "," << + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3] << "," << + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4] << + " -DCK_PARAM_CThreadTransferDstScalarPerVector=" << + CThreadTransferDstScalarPerVector << + " -DCK_PARAM_HasMainKBlockLoop=" << + static_cast(HasMainKBlockLoop) << + " -DCK_PARAM_HasDoubleTailKBlockLoop=" << + static_cast(HasDoubleTailKBlockLoop); + // clang-format on + + return param.str(); + } + + ck::DataTypeEnum_t ABDataTypeEnum = ck::DataTypeEnum_t::Unknown; + ck::DataTypeEnum_t AccDataTypeEnum = ck::DataTypeEnum_t::Unknown; + ck::DataTypeEnum_t CDataTypeEnum = ck::DataTypeEnum_t::Unknown; + + int BlockSize = -1; + + int GN0 = -1; + int GK1 = -1; + + int GM1PerBlockGM11 = -1; + int GN1PerBlockGN11 = -1; + int GK0PerBlock = -1; + + int BM1PerThreadBM11 = -1; + int BN1PerThreadBN11 = -1; + int BK0PerThread = -1; + + std::array BM10BN10ThreadClusterBM10Xs = {-1, -1}; + std::array BM10BN10ThreadClusterBN10Xs = {-1, -1}; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = { + -1, -1, -1, -1, -1}; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = { + -1, -1, -1, -1, -1}; + std::array 
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { + -1, -1, -1, -1, -1}; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { + -1, -1, -1, -1, -1}; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = { + -1, -1, -1, -1, -1}; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = { + -1, -1, -1, -1, -1}; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { + -1, -1, -1, -1, -1}; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { + -1, -1, -1, -1, -1}; + + int CThreadTransferDstScalarPerVector = -1; + + bool HasMainKBlockLoop = false; + bool HasDoubleTailKBlockLoop = false; +}; + +struct TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + ck::DataTypeEnum_t ABDataTypeEnum; + ck::DataTypeEnum_t CDataTypeEnum; + + int BlockSize; + + int GN0; + int GK1; + + int GM1PerBlockGM11; + int GN1PerBlockGN11; + int GK0PerBlock; + + int BM1PerThreadBM11; + int BN1PerThreadBN11; + int BK0PerThread; + + std::array BM10BN10ThreadClusterBM10Xs; + std::array BM10BN10ThreadClusterBN10Xs; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; +}; + +inline static auto generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw() +{ + constexpr auto f32 = ck::DataTypeEnum_t::Float; + constexpr auto f16 = ck::DataTypeEnum_t::Half; + constexpr auto i8 = ck::DataTypeEnum_t::Int8; + + return std::vector{ + // clang-format off + // fp32 + {f32, f32, 256, 1, 1, 128, 
128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 1}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + {f32, f32, 256, 1, 1, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 1}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f32, f32, 256, 2, 1, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 1}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f32, f32, 256, 4, 1, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 1}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f32, f32, 256, 8, 1, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 1}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 1}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f32, f32, 128, 1, 1, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 1}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 1}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + // fp16 + {f16, f16, 256, 1, 2, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 2}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 
1}, {1, 1, 1, 4, 1}}, + + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + {f16, f16, 256, 1, 2, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 2}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f16, f16, 256, 2, 2, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 2}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + {f16, f16, 256, 4, 2, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 2}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f16, f16, 256, 8, 2, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 2}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 2}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + {f16, f16, 128, 1, 2, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 2}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 2}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + // i8 + { i8, i8, 256, 1, 4, 128, 128, 16, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 2, 4}, {4, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, 
{1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 4, 1}, {1, 1, 1, 4, 1}}, + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 2, 1}, {1, 1, 1, 4, 1}}, + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 4, 1}}, + + { i8, i8, 256, 1, 4, 128, 128, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {4, 1, 1, 1, 4}, { 2, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + { i8, i8, 256, 2, 4, 128, 64, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 2, 1, 1, 4}, { 4, 1, 1, 64, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + { i8, i8, 256, 4, 4, 128, 32, 8, 4, 4, 1, {8, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 4, 1, 1, 4}, { 8, 1, 1, 32, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + { i8, i8, 256, 8, 4, 128, 16, 16, 4, 4, 1, {8, 2}, {8, 2}, {8, 1, 1, 1, 4}, {2, 1, 1, 128, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 8, 1, 1, 4}, {16, 1, 1, 16, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}}, + + { i8, i8, 128, 1, 4, 64, 128, 8, 4, 4, 1, {4, 2}, {8, 2}, {4, 1, 1, 1, 4}, {2, 1, 1, 64, 1}, {4, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {8, 1, 1, 1, 4}, { 1, 1, 1, 128, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}} + // clang-format on + }; +} + +// TODO make this common interface and write specs for it +struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw +{ + static auto + CalculateCompileParameterBasedOnTunable(const ConvolutionProblemDescriptor& conv_problem_desc, + const TunableConvIgemmFwdV6r1DlopsNchwKcyxNkhw& tunable) + { + const int C = conv_problem_desc.C; + const int Y = conv_problem_desc.Y; + const int X = conv_problem_desc.X; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + 
if(!(conv_problem_desc.InDataTypeEnum == tunable.ABDataTypeEnum && + conv_problem_desc.WeiDataTypeEnum == tunable.ABDataTypeEnum && + conv_problem_desc.OutDataTypeEnum == tunable.CDataTypeEnum)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const auto ABDataTypeEnum = conv_problem_desc.InDataTypeEnum; + const auto CDataTypeEnum = conv_problem_desc.OutDataTypeEnum; + + DataTypeEnum_t AccDataTypeEnum; + + if(ABDataTypeEnum == DataTypeEnum_t::Float || ABDataTypeEnum == DataTypeEnum_t::Half) + { + AccDataTypeEnum = DataTypeEnum_t::Float; + } + else if(ABDataTypeEnum == DataTypeEnum_t::Int8) + { + AccDataTypeEnum = DataTypeEnum_t::Int32; + } + else + { + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + } + + const int BlockSize = tunable.BlockSize; + + const int GN0 = tunable.GN0; + const int GK1 = tunable.GK1; + + const int GM11 = tunable.GM1PerBlockGM11; + const int GN11 = tunable.GN1PerBlockGN11; + const int GK0PerBlock = tunable.GK0PerBlock; + + const int BM11 = tunable.BM1PerThreadBM11; + const int BN11 = tunable.BN1PerThreadBN11; + const int BK0PerThread = tunable.BK0PerThread; + + const auto BM10BN10ThreadClusterBM10Xs = tunable.BM10BN10ThreadClusterBM10Xs; + const auto BM10BN10ThreadClusterBN10Xs = tunable.BM10BN10ThreadClusterBN10Xs; + + const auto ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + const auto ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + const auto BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = + 
tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + const auto BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + // C threadwise copy: {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim + const int CThreadTransferDstScalarPerVector = gcd(4, GN11, BN11, Ho * Wo); + + const int C0 = GK1; + + if(!(C % C0 == 0)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const int C1 = C / C0; + + const int GK0 = C1 * Y * X; + + if(!(GK0 % GK0PerBlock == 0)) + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + + const bool HasMainKBlockLoop = ((GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1); + + const bool HasDoubleTailKBlockLoop = ((GK0 / GK0PerBlock) % 2 == 0); + + return std::make_tuple( + CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{ + ABDataTypeEnum, + AccDataTypeEnum, + CDataTypeEnum, + BlockSize, + GN0, + GK1, + GM11, + GN11, + GK0PerBlock, + BM11, + BN11, + BK0PerThread, + BM10BN10ThreadClusterBM10Xs, + BM10BN10ThreadClusterBN10Xs, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + CThreadTransferDstScalarPerVector, + HasMainKBlockLoop, + HasDoubleTailKBlockLoop}, + true); + } + + static auto 
GetDefaultCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc) + { + for(const auto& tunable : generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw()) + { + CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param{}; + bool found = false; + + std::tie(compile_param, found) = + CalculateCompileParameterBasedOnTunable(conv_problem_desc, tunable); + + if(found && IsValidCompileParameter(conv_problem_desc, compile_param)) + return std::make_tuple(compile_param, true); + } + + return std::make_tuple(CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw{}, false); + } + + static bool IsApplicable(const ConvolutionProblemDescriptor& conv_problem_desc) + { + bool found = false; + + std::tie(std::ignore, found) = GetDefaultCompileParameter(conv_problem_desc); + + return found; + } + + static bool + IsValidCompileParameter(const ConvolutionProblemDescriptor& conv_problem_desc, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + const int N = conv_problem_desc.N; + const int K = conv_problem_desc.K; + const int C = conv_problem_desc.C; + const int Y = conv_problem_desc.Y; + const int X = conv_problem_desc.X; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + const int GK1 = compile_param.GK1; + const int GN0 = compile_param.GN0; + const int GM11 = compile_param.GM1PerBlockGM11; + const int GN11 = compile_param.GN1PerBlockGN11; + + const int BM11 = compile_param.BM1PerThreadBM11; + const int BN11 = compile_param.BN1PerThreadBN11; + + const int C0 = GK1; + const int N0 = GN0; + + if(!(C % C0 == 0)) + return false; + + const int C1 = C / C0; + + if(!(N % N0 == 0)) + return false; + + const int N1 = N / N0; + + const int GM0 = 1; + const int GM1 = K; + const int GN1 = N1 * Ho * Wo; + const int GK0 = C1 * Y * X; + + // check data type + { + if(!(conv_problem_desc.InDataTypeEnum == conv_problem_desc.WeiDataTypeEnum && + conv_problem_desc.InDataTypeEnum == compile_param.ABDataTypeEnum)) + 
return false; + + if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Float || + compile_param.ABDataTypeEnum == DataTypeEnum_t::Half) + { + if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Float)) + return false; + } + else if(compile_param.ABDataTypeEnum == DataTypeEnum_t::Int8) + { + if(!(compile_param.AccDataTypeEnum == DataTypeEnum_t::Int32)) + return false; + } + } + + // check gridwise contraction + { + if(!(GM1 % GM11 == 0 && GN1 % GN11 == 0 && GK0 % compile_param.GK0PerBlock == 0)) + return false; + + const bool has_main_k_block_loop = + ((GK0 + compile_param.GK0PerBlock) / (2 * compile_param.GK0PerBlock) > 1); + + const bool has_double_tail_k_block_loop = ((GK0 / compile_param.GK0PerBlock) % 2 == 0); + + if(!(has_main_k_block_loop == compile_param.HasMainKBlockLoop && + has_double_tail_k_block_loop == compile_param.HasDoubleTailKBlockLoop)) + return false; + } + + // check A blockwise copy + { + const auto block_slice_lengths = + std::array{compile_param.GK0PerBlock, GM0, 1, GM11, GK1}; + const auto& cluster_lengths = + compile_param.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1; + const auto& thread_slice_lengths = + compile_param.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1; + const auto& src_vector_lengths = + compile_param.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + const auto& dst_vector_lengths = + compile_param.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1; + + // check number of working thread + const int num_work_thread = std::accumulate( + cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); + + if(!(compile_param.BlockSize >= num_work_thread)) + return false; + + // check block slice lengths vs thread slice lengths vs cluster lengths + for(int i = 0; i < 5; ++i) + { + if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) + return false; + } + + // check thread slice lengths vs vector lengths + for(int i = 0; i < 5; ++i) + { + 
if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0)) + return false; + + if(!(thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) + return false; + } + + // check Src vectorization, GK0 is global mem vector dim + if(!(src_vector_lengths[1] == 1 && src_vector_lengths[2] == 1 && + src_vector_lengths[3] == 1 && src_vector_lengths[4] == 1)) + return false; + + // check Dst vectorization, {GM11, GK1} are LDS vector dims + if(dst_vector_lengths[4] == GK1) + { // vectorize on {GM11, GK1} + if(!(GM11 % dst_vector_lengths[3] == 0)) + return false; + } + else + { // vectorize on {GK1} only + if(!(GK1 % dst_vector_lengths[4] == 0)) + return false; + + if(!(dst_vector_lengths[3] == 1)) + return false; + } + } + + // check B blockwise copy + { + const auto block_slice_lengths = + std::array{compile_param.GK0PerBlock, GN0, 1, GN11, GK1}; + const auto& cluster_lengths = + compile_param.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1; + const auto& thread_slice_lengths = + compile_param.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1; + const auto& src_vector_lengths = + compile_param.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + const auto& dst_vector_lengths = + compile_param.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1; + + // check number of working thread + const int num_work_thread = std::accumulate( + cluster_lengths.begin(), cluster_lengths.end(), 1, std::multiplies{}); + + if(!(compile_param.BlockSize >= num_work_thread)) + return false; + + // check block slice lengths vs thread slice lengths vs cluster lengths + for(int i = 0; i < 5; ++i) + { + if(!(cluster_lengths[i] * thread_slice_lengths[i] == block_slice_lengths[i])) + return false; + } + + // check thread slice lengths vs vector lengths + for(int i = 0; i < 5; ++i) + { + if(!(thread_slice_lengths[i] % src_vector_lengths[i] == 0 && + thread_slice_lengths[i] % dst_vector_lengths[i] == 0)) + return false; + } + + // check Src vectorization: {GN11} is global mem 
vector dim + if(!(src_vector_lengths[0] == 1 && src_vector_lengths[1] == 1 && + src_vector_lengths[2] == 1 && src_vector_lengths[4] == 1)) + return false; + + // check Src tensor layout related vectorization + if(Y == 1 && X == 1 && conv_problem_desc.ConvStrideH == 1 && + conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadH == 0 && + conv_problem_desc.InLeftPadW == 0 && conv_problem_desc.InRightPadH == 0 && + conv_problem_desc.InRightPadW == 0) + { + if(!((Ho * Wo) % src_vector_lengths[3] == 0)) + return false; + } + else if(conv_problem_desc.ConvStrideW == 1 && conv_problem_desc.InLeftPadW == 0 && + conv_problem_desc.InRightPadW == 0) + { + if(!(Wo % src_vector_lengths[3] == 0)) + return false; + } + else + { + if(!(src_vector_lengths[3] == 1)) + return false; + } + + // check Dst vectorization: {GN11, GK1} are LDS vector dims + if(dst_vector_lengths[4] == GK1) + { // vectorize on {GN11, GK1} + if(!(GN11 % dst_vector_lengths[3] == 0)) + return false; + } + else + { // vectorize on {GK1} only + if(!(dst_vector_lengths[3] == 1)) + return false; + + if(!(GK1 % dst_vector_lengths[4] == 0)) + return false; + } + } + + // check blockwise GEMM + { + const int BM10 = std::accumulate(compile_param.BM10BN10ThreadClusterBM10Xs.begin(), + compile_param.BM10BN10ThreadClusterBM10Xs.end(), + 1, + std::multiplies{}); + + const int BN10 = std::accumulate(compile_param.BM10BN10ThreadClusterBN10Xs.begin(), + compile_param.BM10BN10ThreadClusterBN10Xs.end(), + 1, + std::multiplies{}); + + if(!(compile_param.BlockSize == BM10 * BN10)) + return false; + + const int BM = GM0 * GM11; + const int BN = GN0 * GN11; + + const int BM1 = BM10 * BM11; + const int BN1 = BN10 * BN11; + + if(!(BM % BM1 == 0 && BN % BN1 == 0)) + return false; + + const int BM0 = BM / BM1; + const int BN0 = BN / BN1; + + // blockwise GEMM currently only support BM0 == 2 && BN0 == 2 + if(!(BM0 == 2 && BN0 == 2)) + return false; + + if(!(compile_param.GK0PerBlock % compile_param.BK0PerThread == 0)) + 
return false; + } + + // check C threadwise copy + { + // {BN11} or {BN} or {BN1} or {GN11} is Dst vector dim + const int dst_vector_len_gn11 = compile_param.CThreadTransferDstScalarPerVector; + + // check slice length vs Dst vector length: + if(!(BN11 % dst_vector_len_gn11 == 0 && GN11 % dst_vector_len_gn11 == 0)) + return false; + + // check Dst memory layout related vectorization: + if(!((Ho * Wo) % compile_param.CThreadTransferDstScalarPerVector == 0)) + return false; + } + + return true; + }; + + static int GetBlockSize(const ConvolutionProblemDescriptor&, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + return compile_param.BlockSize; + } + + static int GetGridSize(const ConvolutionProblemDescriptor& conv_problem_desc, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param) + { + const int N = conv_problem_desc.N; + const int K = conv_problem_desc.K; + const int Ho = conv_problem_desc.Ho; + const int Wo = conv_problem_desc.Wo; + + const int N0 = compile_param.GN0; + const int N1 = N / N0; + + const int GM1 = K; + const int GN1 = N1 * Ho * Wo; + + const int GM11 = compile_param.GM1PerBlockGM11; + const int GN11 = compile_param.GN1PerBlockGN11; + + const int GM10 = GM1 / GM11; + const int GN10 = GN1 / GN11; + + return GM10 * GN10; + } + + static std::size_t GetWorkSpaceSize(const ConvolutionProblemDescriptor&, + const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw&) + { + // workspace is used for save transformed tensor descritpors created by prepare kernel + return 4096L; + } + + static std::size_t GetMaxWorkSpaceSize(const ConvolutionProblemDescriptor&) { return 4096L; } + + static auto GetTunableList() + { + return generate_tunable_list_conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw(); + } +}; + +} // namespace driver +} // namespace ck +#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp new file mode 100644 
index 0000000000..58fe588ad9 --- /dev/null +++ b/host/solver/include/conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,51 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_DLOPS_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + int KPerBlock; + + int M1PerThread; + int N1PerThread; + int KPerThread; + + int M1N1ThreadClusterM10; + int M1N1ThreadClusterN10; + int M1N1ThreadClusterM11; + int M1N1ThreadClusterN11; + + std::array ABlockTransferThreadSliceLengths_K_M0_M1; + std::array ABlockTransferThreadClusterLengths_K_M0_M1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_M1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K_N0_N1; + std::array BBlockTransferThreadClusterLengths_K_N0_N1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_N1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw + default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw = { + 256, 128, 128, 8, 4, 4, 1, + 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, + {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, + {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, + 5, 1}; +#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..97ce326346 --- /dev/null +++ 
b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,73 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + int KPerBlock; + + int MPerWave; + int NPerWave; + int K1; + + int MRepeat; + int NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_K1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_K1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw + default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = { + 256, // BlockSize + 128, // MPerBlock, + 128, // NPerBlock, + 4, // KPerBlock, + 32, // MPerWave, + 32, // NPerWave, + 4, // K1, + 2, // MRepeat, + 2, // NRepeat, + {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, + {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // 
BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; +#endif diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..263c21a13b --- /dev/null +++ b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,73 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP +#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP + +struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk +{ + int BlockSize; + + int MPerBlock; + int NPerBlock; + int KPerBlock; + + int MPerWave; + int NPerWave; + int K1; + + int MRepeat; + int NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int ABlockTransferSrcVectorDim; + int ABlockTransferSrcScalarPerVector; + int ABlockTransferDstScalarPerVector_K1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int BBlockTransferSrcVectorDim; + int BBlockTransferSrcScalarPerVector; + int BBlockTransferDstScalarPerVector_K1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int 
CThreadTransferSrcDstVectorDim; + int CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk + default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = { + 256, // BlockSize + 128, // MPerBlock, + 128, // NPerBlock, + 4, // KPerBlock, + 32, // MPerWave, + 32, // NPerWave, + 4, // K1, + 2, // MRepeat, + 2, // NRepeat, + {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, + {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; +#endif diff --git a/host/solver/include/convolution_problem_descriptor.hpp b/host/solver/include/convolution_problem_descriptor.hpp new file mode 100644 index 0000000000..8c0ecbee80 --- /dev/null +++ b/host/solver/include/convolution_problem_descriptor.hpp @@ -0,0 +1,81 @@ +#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR +#define CONVOLUTION_PROBLEM_DESCRIPTOR + +namespace ck { +namespace driver { + +struct ConvolutionProblemDescriptor +{ + ConvolutionProblemDescriptor() = default; + + ConvolutionProblemDescriptor(int N_, + int K_, + int C_, + int Y_, + int X_, + int Hi_, + int Wi_, + int Ho_, + int Wo_, + int ConvStrideH_, + int ConvStrideW_, + int ConvDilationH_, + int ConvDilationW_, + int 
InLeftPadH_, + int InLeftPadW_, + int InRightPadH_, + int InRightPadW_, + ck::DataTypeEnum_t InDataTypeEnum_, + ck::DataTypeEnum_t WeiDataTypeEnum_, + ck::DataTypeEnum_t OutDataTypeEnum_) + : N{N_}, + K{K_}, + C{C_}, + Y{Y_}, + X{X_}, + Hi{Hi_}, + Wi{Wi_}, + Ho{Ho_}, + Wo{Wo_}, + ConvStrideH{ConvStrideH_}, + ConvStrideW{ConvStrideW_}, + ConvDilationH{ConvDilationH_}, + ConvDilationW{ConvDilationW_}, + InLeftPadH{InLeftPadH_}, + InLeftPadW{InLeftPadW_}, + InRightPadH{InRightPadH_}, + InRightPadW{InRightPadW_}, + InDataTypeEnum{InDataTypeEnum_}, + WeiDataTypeEnum{WeiDataTypeEnum_}, + OutDataTypeEnum{OutDataTypeEnum_} + { + } + + int N; + int K; + int C; + int Y; + int X; + int Hi; + int Wi; + int Ho; + int Wo; + int ConvStrideH; + int ConvStrideW; + int ConvDilationH; + int ConvDilationW; + int InLeftPadH; + int InLeftPadW; + int InRightPadH; + int InRightPadW; + + ck::DataTypeEnum_t InDataTypeEnum; + ck::DataTypeEnum_t WeiDataTypeEnum; + ck::DataTypeEnum_t OutDataTypeEnum; + + std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; } +}; + +} // namespace driver +} // namespace ck +#endif diff --git a/host/solver/include/solver_common.hpp b/host/solver/include/solver_common.hpp new file mode 100644 index 0000000000..d1792f7681 --- /dev/null +++ b/host/solver/include/solver_common.hpp @@ -0,0 +1,46 @@ +#ifndef CK_SOLVER_COMMON_HPP +#define CK_SOLVER_COMMON_HPP + +namespace ck { +namespace driver { + +// greatest common divisor, aka highest common factor +inline int gcd(int x, int y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template = 2, bool>::type = false> +auto gcd(X x, Ys... 
ys) +{ + return gcd(x, gcd(ys...)); +} + +} // namespace driver +} // namespace ck +#endif diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh new file mode 100755 index 0000000000..ebfa2b9f69 --- /dev/null +++ b/script/cmake-rocm.sh @@ -0,0 +1,18 @@ +#!/bin/bash +rm -f CMakeCache.txt +rm -f *.cmake +rm -rf CMakeFiles + +MY_PROJECT_SOURCE=../../.. +MY_PROJECT_INSTALL=../install.dir + +cmake \ +-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ +-D HALF_INCLUDE_DIR="/root/workspace/external/half/include" \ +-D BUILD_DEV=ON \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +${MY_PROJECT_SOURCE} diff --git a/script/count_vgpr.sh b/script/count_vgpr.sh new file mode 100755 index 0000000000..4fbfec0278 --- /dev/null +++ b/script/count_vgpr.sh @@ -0,0 +1,259 @@ +#!/bin/bash +FILE=$1 + +echo v0 $( grep -w v0 $FILE | wc -l ) +echo v1 $( grep -w v1 $FILE | wc -l ) +echo v2 $( grep -w v2 $FILE | wc -l ) +echo v3 $( grep -w v3 $FILE | wc -l ) +echo v4 $( grep -w v4 $FILE | wc -l ) +echo v5 $( grep -w v5 $FILE | wc -l ) +echo v6 $( grep -w v6 $FILE | wc -l ) +echo v7 $( grep -w v7 $FILE | wc -l ) +echo v8 $( grep -w v8 $FILE | wc -l ) +echo v9 $( grep -w v9 $FILE | wc -l ) +echo v10 $( grep -w v10 $FILE | wc -l ) +echo v11 $( grep -w v11 $FILE | wc -l ) +echo v12 $( grep -w v12 $FILE | wc -l ) +echo v13 $( grep -w v13 $FILE | wc -l ) +echo v14 $( grep -w v14 $FILE | wc -l ) +echo v15 $( grep -w v15 $FILE | wc -l ) +echo v16 $( grep -w v16 $FILE | wc -l ) +echo v17 $( grep -w v17 $FILE | wc -l ) +echo v18 $( grep -w v18 $FILE | wc -l ) +echo v19 $( grep -w v19 $FILE | wc -l ) +echo v20 $( grep -w v20 $FILE | wc -l ) +echo v21 $( grep -w v21 $FILE | wc -l ) +echo v22 $( grep -w v22 $FILE | wc -l ) +echo v23 $( grep -w v23 $FILE | wc -l ) 
+echo v24 $( grep -w v24 $FILE | wc -l ) +echo v25 $( grep -w v25 $FILE | wc -l ) +echo v26 $( grep -w v26 $FILE | wc -l ) +echo v27 $( grep -w v27 $FILE | wc -l ) +echo v28 $( grep -w v28 $FILE | wc -l ) +echo v29 $( grep -w v29 $FILE | wc -l ) +echo v30 $( grep -w v30 $FILE | wc -l ) +echo v31 $( grep -w v31 $FILE | wc -l ) +echo v32 $( grep -w v32 $FILE | wc -l ) +echo v33 $( grep -w v33 $FILE | wc -l ) +echo v34 $( grep -w v34 $FILE | wc -l ) +echo v35 $( grep -w v35 $FILE | wc -l ) +echo v36 $( grep -w v36 $FILE | wc -l ) +echo v37 $( grep -w v37 $FILE | wc -l ) +echo v38 $( grep -w v38 $FILE | wc -l ) +echo v39 $( grep -w v39 $FILE | wc -l ) +echo v40 $( grep -w v40 $FILE | wc -l ) +echo v41 $( grep -w v41 $FILE | wc -l ) +echo v42 $( grep -w v42 $FILE | wc -l ) +echo v43 $( grep -w v43 $FILE | wc -l ) +echo v44 $( grep -w v44 $FILE | wc -l ) +echo v45 $( grep -w v45 $FILE | wc -l ) +echo v46 $( grep -w v46 $FILE | wc -l ) +echo v47 $( grep -w v47 $FILE | wc -l ) +echo v48 $( grep -w v48 $FILE | wc -l ) +echo v49 $( grep -w v49 $FILE | wc -l ) +echo v50 $( grep -w v50 $FILE | wc -l ) +echo v51 $( grep -w v51 $FILE | wc -l ) +echo v52 $( grep -w v52 $FILE | wc -l ) +echo v53 $( grep -w v53 $FILE | wc -l ) +echo v54 $( grep -w v54 $FILE | wc -l ) +echo v55 $( grep -w v55 $FILE | wc -l ) +echo v56 $( grep -w v56 $FILE | wc -l ) +echo v57 $( grep -w v57 $FILE | wc -l ) +echo v58 $( grep -w v58 $FILE | wc -l ) +echo v59 $( grep -w v59 $FILE | wc -l ) +echo v60 $( grep -w v60 $FILE | wc -l ) +echo v61 $( grep -w v61 $FILE | wc -l ) +echo v62 $( grep -w v62 $FILE | wc -l ) +echo v63 $( grep -w v63 $FILE | wc -l ) +echo v64 $( grep -w v64 $FILE | wc -l ) +echo v65 $( grep -w v65 $FILE | wc -l ) +echo v66 $( grep -w v66 $FILE | wc -l ) +echo v67 $( grep -w v67 $FILE | wc -l ) +echo v68 $( grep -w v68 $FILE | wc -l ) +echo v69 $( grep -w v69 $FILE | wc -l ) +echo v70 $( grep -w v70 $FILE | wc -l ) +echo v71 $( grep -w v71 $FILE | wc -l ) +echo v72 $( grep -w v72 $FILE 
| wc -l ) +echo v73 $( grep -w v73 $FILE | wc -l ) +echo v74 $( grep -w v74 $FILE | wc -l ) +echo v75 $( grep -w v75 $FILE | wc -l ) +echo v76 $( grep -w v76 $FILE | wc -l ) +echo v77 $( grep -w v77 $FILE | wc -l ) +echo v78 $( grep -w v78 $FILE | wc -l ) +echo v79 $( grep -w v79 $FILE | wc -l ) +echo v80 $( grep -w v80 $FILE | wc -l ) +echo v81 $( grep -w v81 $FILE | wc -l ) +echo v82 $( grep -w v82 $FILE | wc -l ) +echo v83 $( grep -w v83 $FILE | wc -l ) +echo v84 $( grep -w v84 $FILE | wc -l ) +echo v85 $( grep -w v85 $FILE | wc -l ) +echo v86 $( grep -w v86 $FILE | wc -l ) +echo v87 $( grep -w v87 $FILE | wc -l ) +echo v88 $( grep -w v88 $FILE | wc -l ) +echo v89 $( grep -w v89 $FILE | wc -l ) +echo v90 $( grep -w v90 $FILE | wc -l ) +echo v91 $( grep -w v91 $FILE | wc -l ) +echo v92 $( grep -w v92 $FILE | wc -l ) +echo v93 $( grep -w v93 $FILE | wc -l ) +echo v94 $( grep -w v94 $FILE | wc -l ) +echo v95 $( grep -w v95 $FILE | wc -l ) +echo v96 $( grep -w v96 $FILE | wc -l ) +echo v97 $( grep -w v97 $FILE | wc -l ) +echo v98 $( grep -w v98 $FILE | wc -l ) +echo v99 $( grep -w v99 $FILE | wc -l ) +echo v100 $( grep -w v100 $FILE | wc -l ) +echo v101 $( grep -w v101 $FILE | wc -l ) +echo v102 $( grep -w v102 $FILE | wc -l ) +echo v103 $( grep -w v103 $FILE | wc -l ) +echo v104 $( grep -w v104 $FILE | wc -l ) +echo v105 $( grep -w v105 $FILE | wc -l ) +echo v106 $( grep -w v106 $FILE | wc -l ) +echo v107 $( grep -w v107 $FILE | wc -l ) +echo v108 $( grep -w v108 $FILE | wc -l ) +echo v109 $( grep -w v109 $FILE | wc -l ) +echo v110 $( grep -w v110 $FILE | wc -l ) +echo v111 $( grep -w v111 $FILE | wc -l ) +echo v112 $( grep -w v112 $FILE | wc -l ) +echo v113 $( grep -w v113 $FILE | wc -l ) +echo v114 $( grep -w v114 $FILE | wc -l ) +echo v115 $( grep -w v115 $FILE | wc -l ) +echo v116 $( grep -w v116 $FILE | wc -l ) +echo v117 $( grep -w v117 $FILE | wc -l ) +echo v118 $( grep -w v118 $FILE | wc -l ) +echo v119 $( grep -w v119 $FILE | wc -l ) +echo v120 $( grep -w 
v120 $FILE | wc -l ) +echo v121 $( grep -w v121 $FILE | wc -l ) +echo v122 $( grep -w v122 $FILE | wc -l ) +echo v123 $( grep -w v123 $FILE | wc -l ) +echo v124 $( grep -w v124 $FILE | wc -l ) +echo v125 $( grep -w v125 $FILE | wc -l ) +echo v126 $( grep -w v126 $FILE | wc -l ) +echo v127 $( grep -w v127 $FILE | wc -l ) +echo v128 $( grep -w v128 $FILE | wc -l ) +echo v129 $( grep -w v129 $FILE | wc -l ) +echo v130 $( grep -w v130 $FILE | wc -l ) +echo v131 $( grep -w v131 $FILE | wc -l ) +echo v132 $( grep -w v132 $FILE | wc -l ) +echo v133 $( grep -w v133 $FILE | wc -l ) +echo v134 $( grep -w v134 $FILE | wc -l ) +echo v135 $( grep -w v135 $FILE | wc -l ) +echo v136 $( grep -w v136 $FILE | wc -l ) +echo v137 $( grep -w v137 $FILE | wc -l ) +echo v138 $( grep -w v138 $FILE | wc -l ) +echo v139 $( grep -w v139 $FILE | wc -l ) +echo v140 $( grep -w v140 $FILE | wc -l ) +echo v141 $( grep -w v141 $FILE | wc -l ) +echo v142 $( grep -w v142 $FILE | wc -l ) +echo v143 $( grep -w v143 $FILE | wc -l ) +echo v144 $( grep -w v144 $FILE | wc -l ) +echo v145 $( grep -w v145 $FILE | wc -l ) +echo v146 $( grep -w v146 $FILE | wc -l ) +echo v147 $( grep -w v147 $FILE | wc -l ) +echo v148 $( grep -w v148 $FILE | wc -l ) +echo v149 $( grep -w v149 $FILE | wc -l ) +echo v150 $( grep -w v150 $FILE | wc -l ) +echo v151 $( grep -w v151 $FILE | wc -l ) +echo v152 $( grep -w v152 $FILE | wc -l ) +echo v153 $( grep -w v153 $FILE | wc -l ) +echo v154 $( grep -w v154 $FILE | wc -l ) +echo v155 $( grep -w v155 $FILE | wc -l ) +echo v156 $( grep -w v156 $FILE | wc -l ) +echo v157 $( grep -w v157 $FILE | wc -l ) +echo v158 $( grep -w v158 $FILE | wc -l ) +echo v159 $( grep -w v159 $FILE | wc -l ) +echo v160 $( grep -w v160 $FILE | wc -l ) +echo v161 $( grep -w v161 $FILE | wc -l ) +echo v162 $( grep -w v162 $FILE | wc -l ) +echo v163 $( grep -w v163 $FILE | wc -l ) +echo v164 $( grep -w v164 $FILE | wc -l ) +echo v165 $( grep -w v165 $FILE | wc -l ) +echo v166 $( grep -w v166 $FILE | wc -l ) 
+echo v167 $( grep -w v167 $FILE | wc -l ) +echo v168 $( grep -w v168 $FILE | wc -l ) +echo v169 $( grep -w v169 $FILE | wc -l ) +echo v170 $( grep -w v170 $FILE | wc -l ) +echo v171 $( grep -w v171 $FILE | wc -l ) +echo v172 $( grep -w v172 $FILE | wc -l ) +echo v173 $( grep -w v173 $FILE | wc -l ) +echo v174 $( grep -w v174 $FILE | wc -l ) +echo v175 $( grep -w v175 $FILE | wc -l ) +echo v176 $( grep -w v176 $FILE | wc -l ) +echo v177 $( grep -w v177 $FILE | wc -l ) +echo v178 $( grep -w v178 $FILE | wc -l ) +echo v179 $( grep -w v179 $FILE | wc -l ) +echo v180 $( grep -w v180 $FILE | wc -l ) +echo v181 $( grep -w v181 $FILE | wc -l ) +echo v182 $( grep -w v182 $FILE | wc -l ) +echo v183 $( grep -w v183 $FILE | wc -l ) +echo v184 $( grep -w v184 $FILE | wc -l ) +echo v185 $( grep -w v185 $FILE | wc -l ) +echo v186 $( grep -w v186 $FILE | wc -l ) +echo v187 $( grep -w v187 $FILE | wc -l ) +echo v188 $( grep -w v188 $FILE | wc -l ) +echo v189 $( grep -w v189 $FILE | wc -l ) +echo v190 $( grep -w v190 $FILE | wc -l ) +echo v191 $( grep -w v191 $FILE | wc -l ) +echo v192 $( grep -w v192 $FILE | wc -l ) +echo v193 $( grep -w v193 $FILE | wc -l ) +echo v194 $( grep -w v194 $FILE | wc -l ) +echo v195 $( grep -w v195 $FILE | wc -l ) +echo v196 $( grep -w v196 $FILE | wc -l ) +echo v197 $( grep -w v197 $FILE | wc -l ) +echo v198 $( grep -w v198 $FILE | wc -l ) +echo v199 $( grep -w v199 $FILE | wc -l ) +echo v200 $( grep -w v200 $FILE | wc -l ) +echo v201 $( grep -w v201 $FILE | wc -l ) +echo v202 $( grep -w v202 $FILE | wc -l ) +echo v203 $( grep -w v203 $FILE | wc -l ) +echo v204 $( grep -w v204 $FILE | wc -l ) +echo v205 $( grep -w v205 $FILE | wc -l ) +echo v206 $( grep -w v206 $FILE | wc -l ) +echo v207 $( grep -w v207 $FILE | wc -l ) +echo v208 $( grep -w v208 $FILE | wc -l ) +echo v209 $( grep -w v209 $FILE | wc -l ) +echo v210 $( grep -w v210 $FILE | wc -l ) +echo v211 $( grep -w v211 $FILE | wc -l ) +echo v212 $( grep -w v212 $FILE | wc -l ) +echo v213 $( grep -w 
v213 $FILE | wc -l ) +echo v214 $( grep -w v214 $FILE | wc -l ) +echo v215 $( grep -w v215 $FILE | wc -l ) +echo v216 $( grep -w v216 $FILE | wc -l ) +echo v217 $( grep -w v217 $FILE | wc -l ) +echo v218 $( grep -w v218 $FILE | wc -l ) +echo v219 $( grep -w v219 $FILE | wc -l ) +echo v220 $( grep -w v220 $FILE | wc -l ) +echo v221 $( grep -w v221 $FILE | wc -l ) +echo v222 $( grep -w v222 $FILE | wc -l ) +echo v223 $( grep -w v223 $FILE | wc -l ) +echo v224 $( grep -w v224 $FILE | wc -l ) +echo v225 $( grep -w v225 $FILE | wc -l ) +echo v226 $( grep -w v226 $FILE | wc -l ) +echo v227 $( grep -w v227 $FILE | wc -l ) +echo v228 $( grep -w v228 $FILE | wc -l ) +echo v229 $( grep -w v229 $FILE | wc -l ) +echo v230 $( grep -w v230 $FILE | wc -l ) +echo v231 $( grep -w v231 $FILE | wc -l ) +echo v232 $( grep -w v232 $FILE | wc -l ) +echo v233 $( grep -w v233 $FILE | wc -l ) +echo v234 $( grep -w v234 $FILE | wc -l ) +echo v235 $( grep -w v235 $FILE | wc -l ) +echo v236 $( grep -w v236 $FILE | wc -l ) +echo v237 $( grep -w v237 $FILE | wc -l ) +echo v238 $( grep -w v238 $FILE | wc -l ) +echo v239 $( grep -w v239 $FILE | wc -l ) +echo v240 $( grep -w v240 $FILE | wc -l ) +echo v241 $( grep -w v241 $FILE | wc -l ) +echo v242 $( grep -w v242 $FILE | wc -l ) +echo v243 $( grep -w v243 $FILE | wc -l ) +echo v244 $( grep -w v244 $FILE | wc -l ) +echo v245 $( grep -w v245 $FILE | wc -l ) +echo v246 $( grep -w v246 $FILE | wc -l ) +echo v247 $( grep -w v247 $FILE | wc -l ) +echo v248 $( grep -w v248 $FILE | wc -l ) +echo v249 $( grep -w v249 $FILE | wc -l ) +echo v250 $( grep -w v250 $FILE | wc -l ) +echo v251 $( grep -w v251 $FILE | wc -l ) +echo v252 $( grep -w v252 $FILE | wc -l ) +echo v253 $( grep -w v253 $FILE | wc -l ) +echo v254 $( grep -w v254 $FILE | wc -l ) +echo v255 $( grep -w v255 $FILE | wc -l ) diff --git a/script/docker-rocm4.1.sh b/script/docker-rocm4.1.sh new file mode 100755 index 0000000000..61cc33c5b8 --- /dev/null +++ b/script/docker-rocm4.1.sh @@ -0,0 
+1,14 @@ +WORKSPACE=$1 +echo "workspace: " $WORKSPACE + +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v $WORKSPACE:/root/workspace \ +rocm/tensorflow:rocm4.1-tf1.15-dev \ +/bin/bash + +#--network host \ diff --git a/script/hipclang_opt.sh b/script/hipclang_opt.sh new file mode 100755 index 0000000000..c51bd51d97 --- /dev/null +++ b/script/hipclang_opt.sh @@ -0,0 +1,25 @@ +rm *.ll *.s + +BC_FILE=$1 + +/opt/rocm/llvm/bin/llvm-dis $BC_FILE -o original.ll +/opt/rocm/llvm/bin/opt -S -inline -inline-threshold=104857 original.ll > inline.ll +/opt/rocm/llvm/bin/opt -S -sroa inline.ll > sroa.ll +/opt/rocm/llvm/bin/opt -S -O3 sroa.ll > o3.ll + +/opt/rocm/llvm/bin/llc -mcpu=gfx906 original.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 inline.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 sroa.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 o3.ll + +#/opt/rocm/llvm/bin/opt -S -O3 -sroa inline.ll > o3.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3.ll > o3_2.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3_2.ll > o3_3.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3_3.ll > o3_4.ll + +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 opt.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 inline.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_2.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_3.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_4.ll diff --git a/script/run.sh b/script/run.sh new file mode 100755 index 0000000000..ecb5c85d81 --- /dev/null +++ b/script/run.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +## GPU visibility + export ROCR_VISIBLE_DEVICE=0 + export GPU_DEVICE_ORDINAL=0 + +## Boost + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + +## Compiling +#export OLC_DEBUG_HIP_VERBOSE=1 +#export OLC_DEBUG_HIP_DUMP=1 +#export OLC_DEBUG_SAVE_TEMP_DIR=1 + + make -j conv_fwd_driver_offline + make -j conv_bwd_driver_offline + make -j conv_fwd_driver_online + +#rm -rf /root/_hip_binary_kernels_/ +#rm -rf /tmp/olCompile* + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 
+INIT=$4 +LOG=$5 +REPEAT=$6 + +################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 + ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + +#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + +#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1