Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 17:00:18 +00:00)
Merge pull request #16 from ROCmSoftwarePlatform/develop
Merge develop into master
.clang-tidy: 3 changes (new file)
@@ -0,0 +1,3 @@
CheckOptions:
  - key: bugprone-reserved-identifier.AllowedIdentifiers
    value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__'
CMakeLists.txt: 164 changes
@@ -1,10 +1,9 @@
cmake_minimum_required(VERSION 2.8.3)
project(modular_convolution)
cmake_minimum_required(VERSION 3.5)
project(composable_kernel)

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

include(TargetFlags)
include(AddKernels)
include(CheckCXXCompilerFlag)

## C++
enable_language(CXX)
@@ -39,4 +38,161 @@ link_libraries(${OpenMP_pthread_LIBRARY})
find_package(HIP REQUIRED)
message(STATUS "Build with HIP ${hip_VERSION}")

## half
#find_path(HALF_INCLUDE_DIR half.hpp)
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")

# CMAKE_CXX_FLAGS
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
    string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")

## tidy
include(EnableCompilerWarnings)
set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
    set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
# Enable tidy on hip
elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
    set(MIOPEN_TIDY_ERRORS ALL)
endif()

include(ClangTidy)
enable_clang_tidy(
    CHECKS
        *
        -abseil-*
        -android-cloexec-fopen
        # Yea we shouldn't be using rand()
        -cert-msc30-c
        -bugprone-exception-escape
        -bugprone-macro-parentheses
        -cert-env33-c
        -cert-msc32-c
        -cert-msc50-cpp
        -cert-msc51-cpp
        -cert-dcl37-c
        -cert-dcl51-cpp
        -clang-analyzer-alpha.core.CastToStruct
        -clang-analyzer-optin.performance.Padding
        -clang-diagnostic-deprecated-declarations
        -clang-diagnostic-extern-c-compat
        -clang-diagnostic-unused-command-line-argument
        -cppcoreguidelines-avoid-c-arrays
        -cppcoreguidelines-avoid-magic-numbers
        -cppcoreguidelines-explicit-virtual-functions
        -cppcoreguidelines-init-variables
        -cppcoreguidelines-macro-usage
        -cppcoreguidelines-non-private-member-variables-in-classes
        -cppcoreguidelines-pro-bounds-array-to-pointer-decay
        -cppcoreguidelines-pro-bounds-constant-array-index
        -cppcoreguidelines-pro-bounds-pointer-arithmetic
        -cppcoreguidelines-pro-type-member-init
        -cppcoreguidelines-pro-type-reinterpret-cast
        -cppcoreguidelines-pro-type-union-access
        -cppcoreguidelines-pro-type-vararg
        -cppcoreguidelines-special-member-functions
        -fuchsia-*
        -google-explicit-constructor
        -google-readability-braces-around-statements
        -google-readability-todo
        -google-runtime-int
        -google-runtime-references
        -hicpp-vararg
        -hicpp-braces-around-statements
        -hicpp-explicit-conversions
        -hicpp-named-parameter
        -hicpp-no-array-decay
        # We really shouldn't use bitwise operators with signed integers, but
        # opencl leaves us no choice
        -hicpp-avoid-c-arrays
        -hicpp-signed-bitwise
        -hicpp-special-member-functions
        -hicpp-uppercase-literal-suffix
        -hicpp-use-auto
        -hicpp-use-equals-default
        -hicpp-use-override
        -llvm-header-guard
        -llvm-include-order
        #-llvmlibc-*
        -llvmlibc-restrict-system-libc-headers
        -llvmlibc-callee-namespace
        -llvmlibc-implementation-in-namespace
        -llvm-else-after-return
        -llvm-qualified-auto
        -misc-misplaced-const
        -misc-non-private-member-variables-in-classes
        -misc-no-recursion
        -modernize-avoid-bind
        -modernize-avoid-c-arrays
        -modernize-pass-by-value
        -modernize-use-auto
        -modernize-use-default-member-init
        -modernize-use-equals-default
        -modernize-use-trailing-return-type
        -modernize-use-transparent-functors
        -performance-unnecessary-value-param
        -readability-braces-around-statements
        -readability-else-after-return
        # we are not ready to use it, but very useful
        -readability-function-cognitive-complexity
        -readability-isolate-declaration
        -readability-magic-numbers
        -readability-named-parameter
        -readability-uppercase-literal-suffix
        -readability-convert-member-functions-to-static
        -readability-qualified-auto
        -readability-redundant-string-init
        # too many narrowing conversions in our code
        -bugprone-narrowing-conversions
        -cppcoreguidelines-narrowing-conversions
        -altera-struct-pack-align
        -cppcoreguidelines-prefer-member-initializer

        ${MIOPEN_TIDY_CHECKS}
    ${MIOPEN_TIDY_ERRORS}
    HEADER_FILTER
        "\.hpp$"
    EXTRA_ARGS
        -DMIOPEN_USE_CLANG_TIDY
)

include(CppCheck)
enable_cppcheck(
    CHECKS
        warning
        style
        performance
        portability
    SUPPRESS
        ConfigurationNotChecked
        constStatement
        duplicateCondition
        noExplicitConstructor
        passedByValue
        preprocessorErrorDirective
        shadowVariable
        unusedFunction
        unusedPrivateFunction
        unusedStructMember
        unmatchedSuppression
    FORCE
    SOURCES
        host/host_tensor/src
        host/driver_offline/src
        composable_kernel/src/kernel_wrapper
    INCLUDE
        host/host_tensor/include
        host/solver/include
        host/driver_offline/include
        composable_kernel/include/*
        ${CMAKE_CURRENT_SOURCE_DIR}/include
        ${CMAKE_CURRENT_BINARY_DIR}/include
    DEFINE
        CPPCHECK=1
        __linux__=1
)

add_subdirectory(host)
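For orientation: the tidy and cppcheck targets configured above are both registered with the analyze umbrella target through mark_as_analyzer(), so the whole static-analysis suite can be driven from the build tree. A sketch of the resulting invocations (the build-directory path is an assumption, not from this PR):

    # cmake --build <build-dir> --target tidy      # clang-tidy over every target registered via clang_tidy_check()
    # cmake --build <build-dir> --target cppcheck  # executes the generated cppcheck.cmake script
    # cmake --build <build-dir> --target analyze   # umbrella target for everything marked via mark_as_analyzer()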
README.md: 10 changes
@@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{216, 256, 8}
b_k0_n_k1_grid_desc{216, 165888, 8}
c_m_n_grid_desc{ 256, 165888}
@@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{288, 1024, 8}
b_k0_n_k1_grid_desc{288, 50176, 8}
c_m_n_grid_desc{ 1024, 50176}
@@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{216, 165888, 8}
b_k0_n_k1_grid_desc{216, 256, 8}
c_m_n_grid_desc{ 165888, 256}
@@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
@@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
@@ -1,40 +0,0 @@

function(add_kernels SRC_DIR KERNEL_FILES)
    set(INIT_KERNELS_LIST)
    set(KERNELS_DECLS)
    foreach(KERNEL_FILE ${KERNEL_FILES})
        if("${CMAKE_VERSION}" VERSION_LESS 3.0)
            configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete)
        else()
            set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE})
        endif()
        get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE)
        string(TOUPPER "${BASE_NAME}" KEY_NAME)
        string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME)
        string(APPEND KERNELS_DECLS "extern const size_t APP_KERNEL_${VAR_NAME}_SIZE;\n")
        string(APPEND KERNELS_DECLS "extern const unsigned char APP_KERNEL_${VAR_NAME}[];\n")
        list(APPEND INIT_KERNELS_LIST " { \"${KEY_NAME}\", std::string(reinterpret_cast<const char*>(APP_KERNEL_${VAR_NAME}), APP_KERNEL_${VAR_NAME}_SIZE) }")
    endforeach()
    string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}")
    configure_file(${SRC_DIR}/kernel.cpp.in ${PROJECT_BINARY_DIR}/kernel.cpp)
endfunction()

function(add_kernel_includes SRC_DIR KERNEL_FILES)
    set(INIT_KERNELS_LIST)
    foreach(KERNEL_FILE ${KERNEL_FILES})
        if("${CMAKE_VERSION}" VERSION_LESS 3.0)
            configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete)
        else()
            set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE})
        endif()
        get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE)
        get_filename_component(FILE_NAME ${KERNEL_FILE} NAME)
        string(TOUPPER "${BASE_NAME}" KEY_NAME)
        string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME)
        list(APPEND INIT_KERNELS_LIST " { \"${FILE_NAME}\", std::string(reinterpret_cast<const char*>(${VAR_NAME}), ${VAR_NAME}_SIZE) }")
    endforeach()
    string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}")
    configure_file(${SRC_DIR}/kernel_includes.cpp.in ${PROJECT_BINARY_DIR}/kernel_includes.cpp)
endfunction()
@@ -24,7 +24,11 @@
#
################################################################################

set(ADD_KERNELS_SOURCE include_inliner.cpp addkernels.cpp)
if(NOT TARGET analyze)
    add_custom_target(analyze)
endif()

add_executable(addkernels EXCLUDE_FROM_ALL ${ADD_KERNELS_SOURCE})
function(mark_as_analyzer)
    add_dependencies(analyze ${ARGN})
endfunction()
cmake/ClangTidy.cmake: 162 changes (new file)
@@ -0,0 +1,162 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
include(CMakeParseArguments)
include(Analyzers)

get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH)

find_program(CLANG_TIDY_EXE
    NAMES
        clang-tidy
        clang-tidy-5.0
        clang-tidy-4.0
        clang-tidy-3.9
        clang-tidy-3.8
        clang-tidy-3.7
        clang-tidy-3.6
        clang-tidy-3.5
    HINTS
        ${CLANG_TIDY_EXE_HINT}
    PATH_SUFFIXES
        compiler/bin
    PATHS
        /opt/rocm/llvm/bin
        /opt/rocm/hcc
        /usr/local/opt/llvm/bin
)

function(find_clang_tidy_version VAR)
    execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT)
    separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}")
    list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX)
    if(VERSION_INDEX GREATER 0)
        math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1")
        list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION)
        set(${VAR} ${VERSION} PARENT_SCOPE)
    else()
        set(${VAR} "0.0" PARENT_SCOPE)
    endif()

endfunction()

if( NOT CLANG_TIDY_EXE )
    message( STATUS "Clang tidy not found" )
    set(CLANG_TIDY_VERSION "0.0")
else()
    find_clang_tidy_version(CLANG_TIDY_VERSION)
    message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}")
endif()

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits)
file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR})
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR})

macro(enable_clang_tidy)
    set(options ANALYZE_TEMPORARY_DTORS ALL)
    set(oneValueArgs HEADER_FILTER)
    set(multiValueArgs CHECKS ERRORS EXTRA_ARGS)

    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}")
    string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}")
    set(CLANG_TIDY_EXTRA_ARGS)
    foreach(ARG ${PARSE_EXTRA_ARGS})
        list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}")
    endforeach()

    set(CLANG_TIDY_ALL)
    if(PARSE_ALL)
        set(CLANG_TIDY_ALL ALL)
    endif()

    message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}")

    if (${PARSE_ANALYZE_TEMPORARY_DTORS})
        set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors")
    endif()

    if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
        set(CLANG_TIDY_ERRORS_ARG "")
    else()
        set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'")
    endif()

    if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
        set(CLANG_TIDY_QUIET_ARG "")
    else()
        set(CLANG_TIDY_QUIET_ARG "-quiet")
    endif()

    if(PARSE_HEADER_FILTER)
        string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}")
    else()
        set(CLANG_TIDY_HEADER_FILTER ".*")
    endif()

    set(CLANG_TIDY_COMMAND
        ${CLANG_TIDY_EXE}
        ${CLANG_TIDY_QUIET_ARG}
        -p ${CMAKE_BINARY_DIR}
        -checks='${CLANG_TIDY_CHECKS}'
        ${CLANG_TIDY_ERRORS_ARG}
        ${CLANG_TIDY_EXTRA_ARGS}
        ${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS}
        -header-filter='${CLANG_TIDY_HEADER_FILTER}'
    )
    add_custom_target(tidy ${CLANG_TIDY_ALL})
    mark_as_analyzer(tidy)
    add_custom_target(tidy-base)
    add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR})
    add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR})
    add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir)
    add_dependencies(tidy-base tidy-make-fixit-dir)
endmacro()

function(clang_tidy_check TARGET)
    get_target_property(SOURCES ${TARGET} SOURCES)
    # TODO: Use generator expressions instead
    # COMMAND ${CLANG_TIDY_COMMAND} $<TARGET_PROPERTY:${TARGET},SOURCES>
    # COMMAND ${CLANG_TIDY_COMMAND} $<JOIN:$<TARGET_PROPERTY:${TARGET},SOURCES>, >
    foreach(SOURCE ${SOURCES})
        if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS"))
            string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file)
            set(tidy_target tidy-target-${TARGET}-${tidy_file})
            add_custom_target(${tidy_target}
                # for some targets clang-tidy not able to get information from .clang-tidy
                DEPENDS ${SOURCE}
                COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
                WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
                COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
            )
            add_dependencies(${tidy_target} ${TARGET})
            add_dependencies(${tidy_target} tidy-base)
            add_dependencies(tidy ${tidy_target})
        endif()
    endforeach()
endfunction()
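To show how a consumer would opt a target into this machinery, a minimal sketch; the target name and source path below are hypothetical, not part of this PR:

    add_library(host_tensor src/host_tensor.cpp)  # hypothetical target
    clang_tidy_check(host_tensor)                 # creates per-source tidy-target-host_tensor-* targets under 'tidy'
                                                  # and exports fixits to ${CMAKE_BINARY_DIR}/fixits/*.yaml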
cmake/CppCheck.cmake: 130 changes (new file)
@@ -0,0 +1,130 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################

include(CMakeParseArguments)
include(ProcessorCount)
include(Analyzers)

find_program(CPPCHECK_EXE
    NAMES
        cppcheck
    PATHS
        /opt/rocm/bin
)

ProcessorCount(CPPCHECK_JOBS)

set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build)
file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR})
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR})

macro(enable_cppcheck)
    set(options FORCE)
    set(oneValueArgs)
    set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES)

    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}")
    string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*")
    file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}")
    set(CPPCHECK_DEFINES)
    foreach(DEF ${PARSE_DEFINE})
        set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}")
    endforeach()

    set(CPPCHECK_UNDEFINES)
    foreach(DEF ${PARSE_UNDEFINE})
        set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}")
    endforeach()

    set(CPPCHECK_INCLUDES)
    foreach(INC ${PARSE_INCLUDE})
        set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}")
    endforeach()

    # set(CPPCHECK_FORCE)
    set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json")
    if(PARSE_FORCE)
        set(CPPCHECK_FORCE --force)
    endif()

    set(SOURCES)
    set(GLOBS)
    foreach(SOURCE ${PARSE_SOURCES})
        get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE)
        if(EXISTS ${ABS_SOURCE})
            if(IS_DIRECTORY ${ABS_SOURCE})
                set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h")
            else()
                set(SOURCES "${SOURCES} ${ABS_SOURCE}")
            endif()
        else()
            set(GLOBS "${GLOBS} ${ABS_SOURCE}")
        endif()
    endforeach()

    file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake "
        file(GLOB_RECURSE GSRCS ${GLOBS})
        set(CPPCHECK_COMMAND
            ${CPPCHECK_EXE}
            -q
            # -v
            # --report-progress
            ${CPPCHECK_FORCE}
            --cppcheck-build-dir=${CPPCHECK_BUILD_DIR}
            --platform=native
            --template=gcc
            --error-exitcode=1
            -j ${CPPCHECK_JOBS}
            ${CPPCHECK_DEFINES}
            ${CPPCHECK_UNDEFINES}
            ${CPPCHECK_INCLUDES}
            --enable=${CPPCHECK_CHECKS}
            --inline-suppr
            --suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions
            ${SOURCES} \${GSRCS}
        )
        string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\")
        message(\"\${CPPCHECK_SHOW_COMMAND}\")
        execute_process(
            COMMAND \${CPPCHECK_COMMAND}
            WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
            RESULT_VARIABLE RESULT
        )
        if(NOT RESULT EQUAL 0)
            message(FATAL_ERROR \"Cppcheck failed\")
        endif()
    ")

    add_custom_target(cppcheck
        COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        COMMENT "cppcheck: Running cppcheck..."
    )
    mark_as_analyzer(cppcheck)
endmacro()
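Since enable_cppcheck() writes a standalone script rather than running cppcheck at configure time, the check can also be invoked by hand; a sketch, assuming the build tree lives in ./build (an assumption, not from this PR):

    # Equivalent to building the 'cppcheck' custom target:
    #   cmake -P build/cppcheck.cmake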
cmake/DoxygenDoc.cmake: 355 changes (new file)
@@ -0,0 +1,355 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
include(CMakeParseArguments)
include(MainDoc)

find_program(DOXYGEN_EXECUTABLE NAMES doxygen
    PATH_SUFFIXES bin
    DOC "Doxygen documentation generator"
)
mark_as_advanced(DOXYGEN_EXECUTABLE)

find_path(DOT_EXECUTABLE NAMES dot
    PATH_SUFFIXES bin
    DOC "Graphviz"
)
mark_as_advanced(DOT_EXECUTABLE)

set(DOXYGEN_ARGS
    ABBREVIATE_BRIEF ALIASES ALLEXTERNALS ALLOW_UNICODE_NAMES ALPHABETICAL_INDEX
    ALWAYS_DETAILED_SEC AUTOLINK_SUPPORT BINARY_TOC BRIEF_MEMBER_DESC BUILTIN_STL_SUPPORT
    CALLER_GRAPH CALL_GRAPH CASE_SENSE_NAMES CHM_FILE CHM_INDEX_ENCODING CITE_BIB_FILES
    CLANG_ASSISTED_PARSING CLANG_OPTIONS CLASS_DIAGRAMS CLASS_GRAPH COLLABORATION_GRAPH
    COLS_IN_ALPHA_INDEX COMPACT_LATEX COMPACT_RTF CPP_CLI_SUPPORT CREATE_SUBDIRS
    DIAFILE_DIRS DIA_PATH DIRECTORY_GRAPH DISABLE_INDEX DISTRIBUTE_GROUP_DOC
    DOCBOOK_OUTPUT DOCBOOK_PROGRAMLISTING DOCSET_BUNDLE_ID DOCSET_FEEDNAME
    DOCSET_PUBLISHER_ID DOCSET_PUBLISHER_NAME DOTFILE_DIRS DOT_CLEANUP DOT_FONTNAME
    DOT_FONTPATH DOT_FONTSIZE DOT_GRAPH_MAX_NODES DOT_IMAGE_FORMAT DOT_MULTI_TARGETS
    DOT_NUM_THREADS
    # DOT_PATH
    DOT_TRANSPARENT DOXYFILE_ENCODING ECLIPSE_DOC_ID ENABLED_SECTIONS
    ENABLE_PREPROCESSING ENUM_VALUES_PER_LINE EXAMPLE_PATH EXAMPLE_PATTERNS
    EXAMPLE_RECURSIVE EXCLUDE EXCLUDE_PATTERNS EXCLUDE_SYMBOLS EXCLUDE_SYMLINKS
    EXPAND_AS_DEFINED EXPAND_ONLY_PREDEF EXTENSION_MAPPING EXTERNAL_GROUPS
    EXTERNAL_PAGES EXTERNAL_SEARCH EXTERNAL_SEARCH_ID EXTRACT_ALL EXTRACT_ANON_NSPACES
    EXTRACT_LOCAL_CLASSES EXTRACT_LOCAL_METHODS EXTRACT_PACKAGE EXTRACT_PRIVATE
    EXTRACT_STATIC EXTRA_PACKAGES EXTRA_SEARCH_MAPPINGS EXT_LINKS_IN_WINDOW
    FILE_PATTERNS FILE_VERSION_FILTER FILTER_PATTERNS FILTER_SOURCE_FILES
    FILTER_SOURCE_PATTERNS FORCE_LOCAL_INCLUDES FORMULA_FONTSIZE FORMULA_TRANSPARENT
    FULL_PATH_NAMES GENERATE_AUTOGEN_DEF GENERATE_BUGLIST GENERATE_CHI
    GENERATE_DEPRECATEDLIST GENERATE_DOCBOOK GENERATE_DOCSET GENERATE_ECLIPSEHELP
    GENERATE_HTML GENERATE_HTMLHELP GENERATE_LATEX GENERATE_LEGEND GENERATE_MAN
    GENERATE_PERLMOD GENERATE_QHP GENERATE_RTF GENERATE_TAGFILE GENERATE_TESTLIST
    GENERATE_TODOLIST GENERATE_TREEVIEW GENERATE_XML GRAPHICAL_HIERARCHY GROUP_GRAPHS
    GROUP_NESTED_COMPOUNDS
    # HAVE_DOT
    HHC_LOCATION HIDE_COMPOUND_REFERENCE HIDE_FRIEND_COMPOUNDS HIDE_IN_BODY_DOCS
    HIDE_SCOPE_NAMES HIDE_UNDOC_CLASSES HIDE_UNDOC_MEMBERS HIDE_UNDOC_RELATIONS
    HTML_COLORSTYLE_GAMMA HTML_COLORSTYLE_HUE HTML_COLORSTYLE_SAT HTML_DYNAMIC_SECTIONS
    HTML_EXTRA_FILES HTML_EXTRA_STYLESHEET HTML_FILE_EXTENSION HTML_FOOTER HTML_HEADER
    HTML_INDEX_NUM_ENTRIES HTML_OUTPUT HTML_STYLESHEET HTML_TIMESTAMP
    IDL_PROPERTY_SUPPORT IGNORE_PREFIX IMAGE_PATH INCLUDED_BY_GRAPH
    INCLUDE_FILE_PATTERNS INCLUDE_GRAPH INCLUDE_PATH INHERIT_DOCS
    INLINE_GROUPED_CLASSES INLINE_INFO INLINE_INHERITED_MEMB INLINE_SIMPLE_STRUCTS
    INLINE_SOURCES INPUT INPUT_ENCODING INPUT_FILTER INTERACTIVE_SVG INTERNAL_DOCS
    JAVADOC_AUTOBRIEF LATEX_BATCHMODE LATEX_BIB_STYLE LATEX_CMD_NAME LATEX_EXTRA_FILES
    LATEX_EXTRA_STYLESHEET LATEX_FOOTER LATEX_HEADER LATEX_HIDE_INDICES LATEX_OUTPUT
    LATEX_SOURCE_CODE LATEX_TIMESTAMP LAYOUT_FILE LOOKUP_CACHE_SIZE MACRO_EXPANSION
    MAKEINDEX_CMD_NAME MAN_EXTENSION MAN_LINKS MAN_OUTPUT MAN_SUBDIR MARKDOWN_SUPPORT
    MATHJAX_CODEFILE MATHJAX_EXTENSIONS MATHJAX_FORMAT MATHJAX_RELPATH
    MAX_DOT_GRAPH_DEPTH MAX_INITIALIZER_LINES MSCFILE_DIRS MSCGEN_PATH
    MULTILINE_CPP_IS_BRIEF OPTIMIZE_FOR_FORTRAN OPTIMIZE_OUTPUT_FOR_C
    OPTIMIZE_OUTPUT_JAVA OPTIMIZE_OUTPUT_VHDL OUTPUT_DIRECTORY OUTPUT_LANGUAGE
    PAPER_TYPE PDF_HYPERLINKS PERLMOD_LATEX PERLMOD_MAKEVAR_PREFIX PERLMOD_PRETTY
    PERL_PATH PLANTUML_CFG_FILE PLANTUML_INCLUDE_PATH PLANTUML_JAR_PATH PREDEFINED
    PROJECT_BRIEF PROJECT_LOGO PROJECT_NAME PROJECT_NUMBER QCH_FILE QHG_LOCATION
    QHP_CUST_FILTER_ATTRS QHP_CUST_FILTER_NAME QHP_NAMESPACE QHP_SECT_FILTER_ATTRS
    QHP_VIRTUAL_FOLDER QT_AUTOBRIEF QUIET RECURSIVE REFERENCED_BY_RELATION
    REFERENCES_LINK_SOURCE REFERENCES_RELATION REPEAT_BRIEF RTF_EXTENSIONS_FILE
    RTF_HYPERLINKS RTF_OUTPUT RTF_SOURCE_CODE RTF_STYLESHEET_FILE SEARCHDATA_FILE
    SEARCHENGINE SEARCHENGINE_URL SEARCH_INCLUDES SEPARATE_MEMBER_PAGES
    SERVER_BASED_SEARCH SHORT_NAMES SHOW_FILES SHOW_GROUPED_MEMB_INC
    SHOW_INCLUDE_FILES SHOW_NAMESPACES SHOW_USED_FILES SIP_SUPPORT
    SKIP_FUNCTION_MACROS SORT_BRIEF_DOCS SORT_BY_SCOPE_NAME SORT_GROUP_NAMES
    SORT_MEMBERS_CTORS_1ST SORT_MEMBER_DOCS SOURCE_BROWSER SOURCE_TOOLTIPS
    STRICT_PROTO_MATCHING STRIP_CODE_COMMENTS STRIP_FROM_INC_PATH STRIP_FROM_PATH
    SUBGROUPING TAB_SIZE TAGFILES TCL_SUBST TEMPLATE_RELATIONS TOC_EXPAND
    TOC_INCLUDE_HEADINGS TREEVIEW_WIDTH TYPEDEF_HIDES_STRUCT UML_LIMIT_NUM_FIELDS
    UML_LOOK USE_HTAGS USE_MATHJAX USE_MDFILE_AS_MAINPAGE USE_PDFLATEX
    VERBATIM_HEADERS WARNINGS WARN_AS_ERROR WARN_FORMAT WARN_IF_DOC_ERROR
    WARN_IF_UNDOCUMENTED WARN_LOGFILE WARN_NO_PARAMDOC XML_OUTPUT XML_PROGRAMLISTING
)

set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file")

function(add_doxygen_doc)
    set(options)
    set(oneValueArgs)
    set(multiValueArgs DEPENDS ${DOXYGEN_ARGS})

    cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n")

    foreach(ARG ${DOXYGEN_ARGS})
        if(PARSE_${ARG})
            string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}})
            file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n")
        endif()
    endforeach()

    if(PARSE_OUTPUT_DIRECTORY)
        if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY})
            file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY})
        endif()
    endif()

    if(DOT_EXECUTABLE)
        file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n")
        file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n")
    else()
        file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n")
    endif()

    add_custom_target(doxygen
        ${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        COMMENT "Building documentation with doxygen"
    )
    if(PARSE_OUTPUT_DIRECTORY)
        clean_doc_output(${PARSE_OUTPUT_DIRECTORY})
    endif()
    mark_as_doc(doxygen)
    if(PARSE_DEPENDS)
        add_dependencies(doxygen ${PARSE_DEPENDS})
    endif()
endfunction()
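A minimal sketch of a call to add_doxygen_doc(); each keyword is forwarded verbatim into the generated doxygen.conf, and the values below are illustrative assumptions, not taken from this PR:

    add_doxygen_doc(
        OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/doc             # created automatically if missing
        INPUT ${PROJECT_SOURCE_DIR}/composable_kernel/include
        PROJECT_NAME composable_kernel
        RECURSIVE YES
        GENERATE_HTML YES
    )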
cmake/EnableCompilerWarnings.cmake: 110 changes (new file)
@@ -0,0 +1,110 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
# - Enable warning all for gcc/clang or use /W4 for visual studio

## Strict warning level
if (MSVC)
    # Use the highest warning level for visual studio.
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w")
    # set(CMAKE_CXX_WARNING_LEVEL 4)
    # if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
    #     string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
    # else ()
    #     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
    # endif ()

    # set(CMAKE_C_WARNING_LEVEL 4)
    # if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
    #     string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
    # else ()
    #     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
    # endif ()

else()
    foreach(COMPILER C CXX)
        set(CMAKE_COMPILER_WARNINGS)
        # use -Wall for gcc and clang
        list(APPEND CMAKE_COMPILER_WARNINGS
            -Wall
            -Wextra
            -Wcomment
            -Wendif-labels
            -Wformat
            -Winit-self
            -Wreturn-type
            -Wsequence-point
            # Shadow is broken on gcc when using lambdas
            # -Wshadow
            -Wswitch
            -Wtrigraphs
            -Wundef
            -Wuninitialized
            -Wunreachable-code
            -Wunused

            -Wno-sign-compare
            -Wno-extra-semi-stmt
        )
        if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
            list(APPEND CMAKE_COMPILER_WARNINGS
                -Weverything
                -Wno-c++98-compat
                -Wno-c++98-compat-pedantic
                -Wno-conversion
                -Wno-double-promotion
                -Wno-exit-time-destructors
                -Wno-extra-semi
                -Wno-float-conversion
                -Wno-gnu-anonymous-struct
                -Wno-gnu-zero-variadic-macro-arguments
                -Wno-missing-prototypes
                -Wno-nested-anon-types
                -Wno-padded
                -Wno-return-std-move-in-c++11
                -Wno-shorten-64-to-32
                -Wno-sign-conversion
                -Wno-unknown-warning-option
                -Wno-unused-command-line-argument
                -Wno-weak-vtables
                -Wno-covered-switch-default
            )
        else()
            if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
                # cmake 3.5.2 does not support >=.
                if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1")
                    list(APPEND CMAKE_COMPILER_WARNINGS
                        -Wno-ignored-attributes)
                endif()
            endif()
            list(APPEND CMAKE_COMPILER_WARNINGS
                -Wno-missing-field-initializers
                -Wno-deprecated-declarations
            )
        endif()
        add_definitions(${CMAKE_COMPILER_WARNINGS})
    endforeach()
endif ()
@@ -1,50 +0,0 @@

function(get_target_property2 VAR TARGET PROPERTY)
    get_target_property(_pflags ${TARGET} ${PROPERTY})
    if(_pflags)
        set(${VAR} ${_pflags} PARENT_SCOPE)
    else()
        set(${VAR} "" PARENT_SCOPE)
    endif()
endfunction()


macro(append_flags FLAGS TARGET PROPERTY PREFIX)
    get_target_property2(_pflags ${TARGET} ${PROPERTY})
    foreach(FLAG ${_pflags})
        if(TARGET ${FLAG})
            target_flags(_pflags2 ${FLAG})
            string(APPEND ${FLAGS} " ${_pflags2}")
        else()
            string(APPEND ${FLAGS} " ${PREFIX}${FLAG}")
        endif()
    endforeach()
endmacro()

macro(append_link_flags FLAGS TARGET PROPERTY)
    get_target_property2(_pflags ${TARGET} ${PROPERTY})
    foreach(FLAG ${_pflags})
        if(TARGET ${FLAG})
            target_flags(_pflags2 ${FLAG})
            string(APPEND ${FLAGS} " ${_pflags2}")
        elseif(FLAG MATCHES "^-.*")
            string(APPEND ${FLAGS} " ${FLAG}")
        elseif(EXISTS ${FLAG})
            string(APPEND ${FLAGS} " ${FLAG}")
        else()
            string(APPEND ${FLAGS} " -l${FLAG}")
        endif()
    endforeach()
endmacro()

function(target_flags FLAGS TARGET)
    set(_flags)
    append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "")
    append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D")
    append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ")
    append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ")
    append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "")
    append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "")
    # message("_flags: ${_flags}")
    set(${FLAGS} ${_flags} PARENT_SCOPE)
endfunction()
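For reference, the removed target_flags() helper flattened a target's INTERFACE_* usage requirements into a single command-line string; a sketch of a call (the imported target name is an assumption, not from this PR):

    target_flags(_hip_flags hip::device)               # expands compile options, definitions,
    message(STATUS "hip::device flags:${_hip_flags}")  # include dirs, and link flags of hip::device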
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -23,9 +23,9 @@ template <typename... Wei,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
    const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
    const auto K0 = K / K1;

    // weight tensor
    const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
        wei_k_y_x_c_grid_desc,
        make_tuple(make_pass_through_transform(K),
                   make_embed_transform(make_tuple(YDot, YTilda),
@@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
        transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
                                            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                                                       make_slice_transform(YDot, I0, YDotSlice),
                                                       make_slice_transform(XDot, I0, XDotSlice),
                                                       make_freeze_transform(IYTilda),
                                                       make_freeze_transform(IXTilda),
                                                       make_pass_through_transform(C)),
                                            make_tuple(Sequence<0>{},
                                                       Sequence<1>{},
                                                       Sequence<3>{},
                                                       Sequence<2>{},
                                                       Sequence<4>{},
                                                       Sequence<5>{}),
                                            make_tuple(Sequence<0, 1>{},
                                                       Sequence<2>{},
                                                       Sequence<3>{},
                                                       Sequence<>{},
                                                       Sequence<>{},
                                                       Sequence<4>{}));
        transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                                               make_slice_transform(YDot, I0, YDotSlice),
                                               make_slice_transform(XDot, I0, XDotSlice),
                                               make_freeze_transform(IYTilda),
                                               make_freeze_transform(IXTilda),
                                               make_pass_through_transform(C)),
                                    make_tuple(Sequence<0>{},
                                               Sequence<1>{},
                                               Sequence<3>{},
                                               Sequence<2>{},
                                               Sequence<4>{},
                                               Sequence<5>{}),
                                    make_tuple(Sequence<0, 1>{},
                                               Sequence<2>{},
                                               Sequence<3>{},
                                               Sequence<>{},
                                               Sequence<>{},
                                               Sequence<4>{}));

#if 1
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_pass_through_transform(C),
@@ -143,7 +143,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_pass_through_transform(C),
@@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(

    // output tensor
    // this add padding check
    const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
        out_n_ho_wo_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Ho, I0, I0),
@@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
        out_n_hop_wop_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YDot, HTilda),
@@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
        transform_dynamic_tensor_descriptor(
        transform_tensor_descriptor(
            out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_slice_transform(YDot, I0, YDotSlice),
@@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
                       Sequence<5, 6>{}));

#if 1
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
@@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
@@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
#endif

    // input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
@@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YTilda, HTilda),
@@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
        in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_freeze_transform(IYTilda),
@@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
                   Sequence<2>{},
                   Sequence<3>{}));

    const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_htildaslice_wtildaslice_c_grid_desc,
        make_tuple(make_pass_through_transform(C),
                   make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),

@@ -2,8 +2,8 @@
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -26,9 +26,9 @@ template <typename... Wei,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
    const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(

    // A: output tensor
    // this add padding check
    const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
        out_n_ho_wo_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Ho, I0, I0),
@@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
        out_n_hop_wop_k_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YDot, HTilda),
@@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
        transform_dynamic_tensor_descriptor(
        transform_tensor_descriptor(
            out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_slice_transform(YDot, I0, YDotSlice),
@@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
                       Sequence<5, 6>{}));

#if 1
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
@@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
        out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
@@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif

    // B: weight tensor
    const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
        wei_k_y_x_c_grid_desc,
        make_tuple(make_pass_through_transform(K),
                   make_embed_transform(make_tuple(YDot, YTilda),
@@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
        transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
                                            make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                                                       make_slice_transform(YDot, I0, YDotSlice),
                                                       make_slice_transform(XDot, I0, XDotSlice),
                                                       make_freeze_transform(IYTilda),
                                                       make_freeze_transform(IXTilda),
                                                       make_pass_through_transform(C)),
                                            make_tuple(Sequence<0>{},
                                                       Sequence<1>{},
                                                       Sequence<3>{},
                                                       Sequence<2>{},
                                                       Sequence<4>{},
                                                       Sequence<5>{}),
                                            make_tuple(Sequence<0, 1>{},
                                                       Sequence<2>{},
                                                       Sequence<3>{},
                                                       Sequence<>{},
                                                       Sequence<>{},
                                                       Sequence<4>{}));
        transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
                                               make_slice_transform(YDot, I0, YDotSlice),
                                               make_slice_transform(XDot, I0, XDotSlice),
                                               make_freeze_transform(IYTilda),
                                               make_freeze_transform(IXTilda),
                                               make_pass_through_transform(C)),
                                    make_tuple(Sequence<0>{},
                                               Sequence<1>{},
                                               Sequence<3>{},
                                               Sequence<2>{},
                                               Sequence<4>{},
                                               Sequence<5>{}),
                                    make_tuple(Sequence<0, 1>{},
                                               Sequence<2>{},
                                               Sequence<3>{},
                                               Sequence<>{},
                                               Sequence<>{},
                                               Sequence<4>{}));

#if 1
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
                   make_pass_through_transform(C),
@@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
        wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
                   make_pass_through_transform(C),
@@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
#endif

    // C: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
@@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(YTilda, HTilda),
@@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
        in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_freeze_transform(IYTilda),
@@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
                   Sequence<2>{},
                   Sequence<3>{}));

    const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        in_n_htildaslice_wtildaslice_c_grid_desc,
        make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
                   make_pass_through_transform(C)),

@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -18,9 +18,9 @@ template <typename... Wei,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
    const auto InRightPadW = in_right_pads[I1];

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
    const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_global_desc =
        transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
                                            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
        transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
                                    make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
    const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
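The three descriptors built above realize implicit im2col: the forward convolution becomes a GEMM with GemmM = K, GemmN = N * Ho * Wo, and GemmK = C * Y * X. As a reference for what that view computes, here is a plain-loop host sketch; the function, raw-pointer layouts, and parameter names are illustrative and not part of this diff:

    // Reference semantics of the implicit-GEMM view above; NCHW in/out, KCYX weights.
    void conv_as_gemm_ref(const float* wei, const float* in, float* out,
                          int N, int K, int C, int Hi, int Wi, int Ho, int Wo,
                          int Y, int X, int StrideH, int StrideW,
                          int DilationH, int DilationW, int PadH, int PadW)
    {
        for(int k = 0; k < K; ++k)                     // GemmM
            for(int nhw = 0; nhw < N * Ho * Wo; ++nhw) // GemmN (merged N, Ho, Wo)
            {
                const int n = nhw / (Ho * Wo), ho = (nhw / Wo) % Ho, wo = nhw % Wo;
                float acc = 0.f;
                for(int cyx = 0; cyx < C * Y * X; ++cyx) // GemmK (merged C, Y, X)
                {
                    const int c = cyx / (Y * X), y = (cyx / X) % Y, x = cyx % X;
                    // pad + embed: hi/wi may fall outside the tensor; skip if so
                    const int hi = ho * StrideH + y * DilationH - PadH;
                    const int wi = wo * StrideW + x * DilationW - PadW;
                    if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi)
                        acc += wei[((k * C + c) * Y + y) * X + x] *
                               in[((n * C + c) * Hi + hi) * Wi + wi];
                }
                out[((n * K + k) * Ho + ho) * Wo + wo] = acc;
            }
    }
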
@@ -109,9 +109,9 @@ template <typename... Wei,
          typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -126,9 +126,6 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

@@ -150,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
    assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
    const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
@@ -167,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_global_desc =
        transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
                                            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
        transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
                                    make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
    const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -192,9 +189,9 @@ template <typename... Wei,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -209,9 +206,6 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
    const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

@@ -235,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
           InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
    const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
    const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_global_desc,
        make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
    const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

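When Y = X = 1 with unit strides and zero padding, Ho = Hi and Wo = Wi, and every output pixel reads exactly one input pixel; the descriptors above therefore skip the pad/embed steps entirely and the convolution is literally a K x C times C x (N*Ho*Wo) GEMM. A plain-loop sketch under those assumptions (names and layouts illustrative):

    // 1x1 conv, stride 1, no pad: out[n][k][h][w] = sum_c wei[k][c] * in[n][c][h][w]
    void conv1x1_ref(const float* wei, const float* in, float* out,
                     int N, int K, int C, int H, int W)
    {
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                for(int hw = 0; hw < H * W; ++hw)
                {
                    float acc = 0.f;
                    for(int c = 0; c < C; ++c)
                        acc += wei[k * C + c] * in[(n * C + c) * H * W + hw];
                    out[(n * K + k) * H * W + hw] = acc;
                }
    }
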
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -18,9 +18,9 @@ template <typename... Wei,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
    const auto InRightPadW = in_right_pads[I1];

    // weight tensor
    const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
    const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmk_gemmn_grid_desc =
        transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
        transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                    make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));
@@ -108,9 +108,9 @@ template <typename... Wei,
          typename InLeftPads,
          typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
    const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -125,9 +125,6 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
    const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
    const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);

    const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
    const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);

    const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
    const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);

@@ -151,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
           InRightPadW == 0);

    // weight tensor
    const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
    const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // input tensor
    const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)),
    const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    // output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

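A layout note on the NHWC/KYXC variant: C is innermost, so the merged GemmK = (Y, X, C) dimension walks (mostly) contiguous memory for both input and weight, which is what makes wide vectorized loads along C possible. A packed-offset sketch, with a hypothetical helper not taken from this diff:

    // Packed NHWC offset: c moves fastest, then wi, hi, n.
    constexpr long offset_nhwc(int n, int hi, int wi, int c, int Hi, int Wi, int C)
    {
        return ((long(n) * Hi + hi) * Wi + wi) * C + c;
    }
    // Stepping c -> c+1 changes the offset by 1; stepping x -> x+1 (dilation 1)
    // changes it by C, so the GemmK walk stays unit-strided within each filter tap.
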
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -20,9 +20,9 @@ template <typename... Wei,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
    const auto GemmK0 = GemmK / GemmK1;

    // weight tensor
    const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
    const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
        wei_gemmk_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
        transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmM)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

    const auto in_gemmk_gemmn_grid_desc =
        transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
        transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
                                    make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
        in_gemmk_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    const auto in_gemmk0_gemmn_gemmk1_grid_desc =
        transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmN)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
        make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
        make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

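The v4r4r2 family additionally unmerges GemmK into (GemmK0, GemmK1), where GemmK1 = GemmK1Value is typically chosen to match the per-thread vector or inner-block width and GemmK0 becomes the outer loop over K-blocks (the code assumes GemmK is divisible by GemmK1). The index relation realized by make_unmerge_transform here is simply:

    // gemmk = gemmk0 * GemmK1 + gemmk1, with 0 <= gemmk1 < GemmK1, so a
    // [GemmK, GemmN] view becomes [GemmK0, GemmN, GemmK1] with K1 innermost.
    inline void split_gemmk(int gemmk, int GemmK1, int& gemmk0, int& gemmk1)
    {
        gemmk0 = gemmk / GemmK1; // outer K-block index
        gemmk1 = gemmk % GemmK1; // inner (vector) index
    }
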
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -20,9 +20,9 @@ template <typename... Wei,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
    const auto GemmK0 = GemmK / GemmK1;

    // weight tensor
    const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
    const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
        wei_gemmk_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
        transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmM)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmk_gemmn_grid_desc =
        transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
        transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                    make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
        in_gemmk_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    const auto in_gemmk0_gemmn_gemmk1_grid_desc =
        transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmN)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -23,9 +23,9 @@ template <typename... In,
          index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
    const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
    const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
    const auto GemmK0 = GemmK / GemmK1;

    // A: input tensor
    const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
        in_n_hi_wi_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
@@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
        in_n_hip_wip_c_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
@@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
        make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

    const auto in_gemmk_gemmm_grid_desc =
        transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                            make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                                       make_merge_transform(make_tuple(N, Ho, Wo))),
                                            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));
        transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                    make_tuple(make_merge_transform(make_tuple(Y, X, C)),
                                               make_merge_transform(make_tuple(N, Ho, Wo))),
                                    make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}));

    const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
        in_gemmk_gemmm_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmM)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    const auto in_gemmk0_gemmm_gemmk1_grid_desc =
        transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmM)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // B: weight tensor
    const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
    const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
        make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1>{}, Sequence<0>{}));

    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
        wei_gemmk_gemmn_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                   make_pass_through_transform(GemmN)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
        transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
                                               make_pass_through_transform(GemmN)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

    // C: output tensor
    const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
    const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
        make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
        make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));

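Note the operand roles in v4r4r4 relative to v4r4r2: the input tensor now forms GEMM A with GemmM = N * Ho * Wo, the weight forms B with GemmN = K, and the NHWK output is read as [N * Ho * Wo, K] with no transpose. A shape-bookkeeping sketch of that view (summary only, not code from this diff):

    // A[gemmk][gemmm] : input,  gemmm = merge(N, Ho, Wo), gemmk = merge(Y, X, C)
    // B[gemmk][gemmn] : weight, gemmn = K
    // C[gemmm][gemmn] : output, packed as [N * Ho * Wo, K], so C = A^T * B
    struct GemmShape { long M, N, K; };
    inline GemmShape v4r4r4_shape(int n, int ho, int wo, int k, int y, int x, int c)
    {
        return {long(n) * ho * wo, long(k), long(y) * x * c};
    }
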
@@ -2,8 +2,8 @@
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -24,9 +24,9 @@ template <typename... Wei,
          typename C0Type>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
@@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
    const auto C1 = C / C0;

    // weight tensor
    const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
        make_tuple(make_unmerge_transform(make_tuple(I1, K)),
                   make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
    const auto wei_gk0_gm0_gm1_gk1_grid_desc =
        transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
                                    make_tuple(make_unmerge_transform(make_tuple(I1, K)),
                                               make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}),
                                    make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));

    // input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
@@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
                   make_unmerge_transform(make_tuple(C0, C1)),
@@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

    const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
        in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
        make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
                   make_pass_through_transform(N0),
@@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(

    // output tensor
    const auto out_n_k_howo_grid_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
        make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));

    const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
        out_n_k_howo_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
                   make_unmerge_transform(make_tuple(I1, K)),
                   make_pass_through_transform(Ho * Wo)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
    const auto out_n0_n1_1_k_howo_grid_desc =
        transform_tensor_descriptor(out_n_k_howo_grid_desc,
                                    make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
                                               make_unmerge_transform(make_tuple(I1, K)),
                                               make_pass_through_transform(Ho * Wo)),
                                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                                    make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));

    const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
    const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
        out_n0_n1_1_k_howo_grid_desc,
        make_tuple(make_pass_through_transform(I1),
                   make_pass_through_transform(K),

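v6r1 targets a tensor-contraction kernel rather than a flat GEMM: M is kept as the pair (GM0, GM1) = (1, K), N as (GN0, GN1) = (N0, N1), and GemmK is factored into (GK0, GK1) with the C0/C1 split, giving the kernel an extra blocking axis per operand. A dense reference loop for the contraction the descriptor names describe (layouts assumed packed; illustrative only):

    // out[gm0][gm1][gn0][gn1] = sum over (gk0, gk1) of
    //     wei[gk0][gm0][gm1][gk1] * in[gk0][gn0][gn1][gk1]
    void contraction_ref(const float* wei, const float* in, float* out,
                         int GK0, int GK1, int GM0, int GM1, int GN0, int GN1)
    {
        for(int m = 0; m < GM0 * GM1; ++m)     // flattened (gm0, gm1)
            for(int n = 0; n < GN0 * GN1; ++n) // flattened (gn0, gn1)
            {
                float acc = 0.f;
                for(int k0 = 0; k0 < GK0; ++k0)
                    for(int k1 = 0; k1 < GK1; ++k1)
                        acc += wei[(k0 * GM0 * GM1 + m) * GK1 + k1] *
                               in[(k0 * GN0 * GN1 + n) * GK1 + k1];
                out[m * GN0 * GN1 + n] = acc;
            }
    }
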
@@ -8,7 +8,7 @@ namespace ck {

template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor_v2(
__host__ __device__ constexpr auto make_cluster_descriptor(
    const Lengths& lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{

@@ -1,5 +1,5 @@
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
#define CK_MULTI_INDEX_TRANSFORM_HPP

#include "common_header.hpp"
#include "multi_index.hpp"
@@ -7,7 +7,7 @@
namespace ck {

template <typename LowLength>
struct DynamicPassThrough
struct PassThrough
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;
@@ -16,9 +16,9 @@ struct DynamicPassThrough

    UpLengths up_lengths_;

    __host__ __device__ constexpr DynamicPassThrough() = default;
    __host__ __device__ constexpr PassThrough() = default;

    __host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length)
    __host__ __device__ constexpr PassThrough(const LowLength& low_length)
        : up_lengths_{make_tuple(low_length)}
    {
    }
@@ -82,33 +82,36 @@ struct DynamicPassThrough
    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicPassThrough, ");
        printf("PassThrough, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

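This header renames every index-transform primitive by dropping the `Dynamic` prefix; the same types now serve compile-time and runtime lengths alike. Each transform maps an upper (transformed) index to a lower (original) index, and PassThrough is the identity case. A dependency-free analogue for orientation (illustration only; the real type works on Tuple lengths and multi-indices):

    // Simplified analogue of PassThrough.
    struct PassThroughSketch
    {
        int low_length; // upper length equals lower length
        int CalculateLowerIndex(int idx_up) const { return idx_up; } // identity
    };
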
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
struct DynamicPad
template <typename LowLength,
          typename LeftPadLength,
          typename RightPadLength,
          bool SkipIsValidCheck = false>
struct Pad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{}));
    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LeftPad left_pad_;
    RightPad right_pad_;
    LeftPadLength left_pad_length_;
    RightPadLength right_pad_length_;

    __host__ __device__ constexpr DynamicPad() = default;
    __host__ __device__ constexpr Pad() = default;

    __host__ __device__ constexpr DynamicPad(const LowLength& low_length,
                                             const LeftPad& left_pad,
                                             const RightPad& right_pad)
        : up_lengths_{make_tuple(low_length + left_pad + right_pad)},
          left_pad_{left_pad},
          right_pad_{right_pad}
    __host__ __device__ constexpr Pad(const LowLength& low_length,
                                      const LeftPadLength& left_pad_length,
                                      const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
          left_pad_length_{left_pad_length},
          right_pad_length_{right_pad_length}
    {
    }

@@ -125,7 +128,7 @@ struct DynamicPad
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_;
        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff,
@@ -161,45 +164,46 @@ struct DynamicPad
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) &&
                                    (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_));
        return SkipIsValidCheck ||
               ((idx_up[Number<0>{}] >= left_pad_length_) &&
                (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LeftPad>::value &&
               is_known_at_compile_time<RightPad>::value;
               is_known_at_compile_time<LeftPadLength>::value &&
               is_known_at_compile_time<RightPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicPad, ");
        printf("Pad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("left_pad_ %d", index_t{left_pad_});
        printf("right_pad_ %d", index_t{right_pad_});
        printf("left_pad_length %d", index_t{left_pad_length_});
        printf("right_pad_length %d", index_t{right_pad_length_});
        printf("}");
    }
};

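Pad's index math in isolation: the upper (padded) axis has length left + low + right, lowering subtracts the left pad, and an upper index is valid only in the un-padded interior. Note that up_length - right_pad equals left_pad + low_length, so the member version above checks exactly this interval. Standalone sketch:

    inline int pad_lower_index(int idx_up, int left_pad) { return idx_up - left_pad; }

    inline bool pad_index_is_valid(int idx_up, int left_pad, int low_length)
    {
        // valid <=> idx_up lands inside [left_pad, left_pad + low_length)
        return idx_up >= left_pad && idx_up < left_pad + low_length;
    }
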
template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false>
struct DynamicLeftPad
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct LeftPad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{}));
    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));

    UpLengths up_lengths_;
    LeftPad left_pad_;
    LeftPadLength left_pad_length_;

    __host__ __device__ constexpr DynamicLeftPad() = default;
    __host__ __device__ constexpr LeftPad() = default;

    __host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length,
                                                 const LeftPad& left_pad)
        : up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad}
    __host__ __device__ constexpr LeftPad(const LowLength& low_length,
                                          const LeftPadLength& left_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
    {
    }

@@ -216,7 +220,7 @@ struct DynamicLeftPad
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_;
        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff,
@@ -252,45 +256,45 @@ struct DynamicLeftPad
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
        return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LeftPad>::value;
               is_known_at_compile_time<LeftPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicLeftPad, ");
        printf("LeftPad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("left_pad_ %d", index_t{left_pad_});
        printf("left_pad_length_ %d", index_t{left_pad_length_});
        printf("}");
    }
};

template <typename LowLength, typename RightPad, bool SkipIsValidCheck = false>
struct DynamicRightPad
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct RightPad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + RightPad{}));
    using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LowLength low_length_;
    RightPad right_pad_;
    RightPadLength right_pad_length_;

    __host__ __device__ constexpr DynamicRightPad() = default;
    __host__ __device__ constexpr RightPad() = default;

    __host__ __device__ constexpr DynamicRightPad(const LowLength& low_length,
                                                  const RightPad& right_pad)
        : up_lengths_{make_tuple(low_length + right_pad)},
    __host__ __device__ constexpr RightPad(const LowLength& low_length,
                                           const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + right_pad_length)},
          low_length_{low_length},
          right_pad_{right_pad}
          right_pad_length_{right_pad_length}
    {
    }

@@ -350,17 +354,17 @@ struct DynamicRightPad
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LowLength>::value &&
               is_known_at_compile_time<RightPad>::value;
               is_known_at_compile_time<RightPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicRightPad, ");
        printf("RightPad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("low_length_ %d", index_t{low_length_});
        printf("left_pad_ %d", index_t{right_pad_});
        printf("right_pad_length_ %d", index_t{right_pad_length_});
        printf("}");
    }
};
@@ -373,8 +377,8 @@ struct DynamicRightPad
// at compile-time
template <typename UpLengths,
          typename Coefficients,
          typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct DynamicEmbed
          typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed
{
    static constexpr index_t NDimUp = UpLengths::Size();

@@ -384,10 +388,10 @@ struct DynamicEmbed
    UpLengths up_lengths_;
    Coefficients coefficients_;

    __host__ __device__ constexpr DynamicEmbed() = default;
    __host__ __device__ constexpr Embed() = default;

    __host__ __device__ constexpr DynamicEmbed(const UpLengths& up_lengths,
                                               const Coefficients& coefficients)
    __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
                                        const Coefficients& coefficients)
        : up_lengths_{up_lengths}, coefficients_{coefficients}
    {
    }
@@ -458,7 +462,7 @@ struct DynamicEmbed
    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicEmbed, ");
        printf("Embed, ");
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("coefficients_ ");
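Embed lowers an n-d upper index by a dot product with per-dimension coefficients; with coefficients (ConvDilationH, ConvStrideH) it is what lets the (Y, Ho) pairs in the convolution transforms above address overlapping input rows, i.e. hi = y * DilationH + ho * StrideH. Standalone sketch:

    // idx_low = sum_i coefficients[i] * idx_up[i]
    inline int embed_lower_index(const int* idx_up, const int* coefficients, int ndim)
    {
        int idx_low = 0;
        for(int i = 0; i < ndim; ++i)
            idx_low += coefficients[i] * idx_up[i];
        return idx_low;
    }
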
@@ -470,30 +474,30 @@ struct DynamicEmbed
// Implementation of the "Merge" transformation primitive that uses regular arithmetic to
// lower the multi-index and a carry-and-borrow check to lower the multi-index delta
template <typename LowLengths>
struct DynamicMerge_v1_carry_check
struct Merge_v1_carry_check
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan = decltype(
        container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{}));
    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr DynamicMerge_v1_carry_check() = default;
    __host__ __device__ constexpr Merge_v1_carry_check() = default;

    __host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths)
    __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }
@@ -555,7 +559,7 @@ struct DynamicMerge_v1_carry_check
        LowerIndex idx_low_length_minus_idx_diff_low_const;
        LowerIndex idx_low_length_plus_idx_diff_low_const;

#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
@@ -698,7 +702,7 @@ struct DynamicMerge_v1_carry_check
        LowerIndex idx_low_length_minus_idx_diff_low_const;
        LowerIndex idx_low_length_plus_idx_diff_low_const;

#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
@@ -838,7 +842,7 @@ struct DynamicMerge_v1_carry_check
        // very expensive.
        LowerIndex idx_diff_low_const;

#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
@@ -981,7 +985,7 @@ struct DynamicMerge_v1_carry_check
    __host__ __device__ void Print() const
    {
        printf("{");
        printf("DynamicMerge_v1_carry_check, ");
        printf("Merge_v1_carry_check, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
@@ -1025,7 +1029,7 @@ struct lambda_merge_generate_MagicDivision_calculate_magic_shift
// 5. When the upper-index is of int32_t type (when index_t is int32_t), its value needs to be
// non-negative.
template <typename LowLengths>
struct DynamicMerge_v2_magic_division
struct Merge_v2_magic_division
{
    static constexpr index_t NDimLow = LowLengths::Size();

@@ -1033,7 +1037,7 @@ struct DynamicMerge_v2_magic_division
    using UpperIndex = MultiIndex<1>;

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    using LowLengthsMagicDivisorMultipiler = decltype(
        generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
@@ -1048,9 +1052,9 @@ struct DynamicMerge_v2_magic_division
    LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr DynamicMerge_v2_magic_division() = default;
    __host__ __device__ constexpr Merge_v2_magic_division() = default;

    __host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths)
    __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_magic_divisor_multiplier_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
@@ -1058,7 +1062,7 @@ struct DynamicMerge_v2_magic_division
          low_lengths_magic_divisor_shift_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
              Number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }
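Merge_v2_magic_division lowers a merged index using precomputed magic-number division instead of hardware integer divide, which is slow on GPUs. The idea: for a fixed divisor d, precompute a multiplier m and shift s such that n / d == (n * m) >> s over the needed range. A small host-side sketch, deliberately restricted to a range that is easy to prove correct (CK's MagicDivision covers full 32-bit indices with a per-divisor shift; this simplification is mine):

    #include <cstdint>

    // Exact for 1 <= d < 2^16 and 0 <= n < 2^16: with m = ceil(2^32 / d),
    // the residual error n * (m*d - 2^32) stays below 2^32, so the shifted
    // product truncates to exactly n / d.
    inline uint64_t magic_multiplier(uint32_t d) { return ((1ull << 32) + d - 1) / d; }

    inline uint32_t magic_divide(uint32_t n, uint64_t m)
    {
        return uint32_t((uint64_t(n) * m) >> 32);
    }
    // e.g. decoding gemmn -> (n, ho, wo) uses two such divisions per index update.
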
@@ -1151,7 +1155,7 @@ struct DynamicMerge_v2_magic_division
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("DynamicMerge_v2_magic_division, ");
|
||||
printf("Merge_v2_magic_division, ");
|
||||
printf("low_lengths_ ");
|
||||
print_multi_index(low_lengths_);
|
||||
printf("low_lengths_magic_divisor_multiplier_ ");
|
||||
@@ -1177,18 +1181,18 @@ struct DynamicMerge_v2_magic_division
|
||||
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
|
||||
// non-negative.
|
||||
template <typename LowLengths>
|
||||
struct DynamicMerge_v2r2_magic_division
|
||||
struct Merge_v2r2_magic_division
|
||||
{
|
||||
static constexpr index_t NDimLow = LowLengths::Size();
|
||||
|
||||
using LowerIndex = MultiIndex<NDimLow>;
|
||||
using UpperIndex = MultiIndex<1>;
|
||||
|
||||
using LowLengthsScan = decltype(
|
||||
container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{}));
|
||||
using LowLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{})));
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));
|
||||
|
||||
using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
|
||||
lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
|
||||
@@ -1204,19 +1208,19 @@ struct DynamicMerge_v2r2_magic_division
|
||||
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
|
||||
UpLengths up_lengths_;
|
||||
|
||||
__host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default;
|
||||
__host__ __device__ constexpr Merge_v2r2_magic_division() = default;
|
||||
|
||||
__host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths)
|
||||
__host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_scan_{
|
||||
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
|
||||
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
|
||||
low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
|
||||
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
|
||||
Number<NDimLow>{})},
|
||||
low_lengths_scan_magic_divisor_shift_{generate_tuple(
|
||||
[&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
|
||||
Number<NDimLow>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))}
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
|
||||
{
|
||||
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
|
||||
}
|
||||
@@ -1308,7 +1312,7 @@ struct DynamicMerge_v2r2_magic_division
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("DynamicMerge_v2r2_magic_division, ");
|
||||
printf("Merge_v2r2_magic_division, ");
|
||||
printf("low_lengths_ ");
|
||||
print_multi_index(low_lengths_);
|
||||
printf("low_lengths_scan ");
|
||||
@@ -1324,7 +1328,7 @@ struct DynamicMerge_v2r2_magic_division
|
||||
};
|
||||
|
||||
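For readers unfamiliar with the magic-division trick these structs rely on, here is a minimal, self-contained host-side sketch of the idea (the round-up method; this is an illustration, not code from this commit — CK's MagicDivision packs the multiplier into 32 bits with further tricks that this sketch deliberately avoids):

#include <cstdint>
#include <cstdio>

// Precompute (m, s) for a fixed divisor d so that n / d == (n * m) >> s
// holds for every 32-bit unsigned n. No hardware divide at "divide" time.
struct MagicDiv
{
    unsigned __int128 multiplier;
    int shift;
};

MagicDiv make_magic(uint32_t d)
{
    int l = 0;
    while((uint64_t{1} << l) < d) // l = ceil(log2(d))
        ++l;
    const int s = 32 + l;
    // m = ceil(2^s / d); 128-bit math because m may not fit in 32 bits here
    const unsigned __int128 m = (((unsigned __int128)1 << s) + d - 1) / d;
    return {m, s};
}

uint32_t magic_divide(uint32_t n, MagicDiv md)
{
    return (uint32_t)((md.multiplier * n) >> md.shift); // multiply + shift only
}

int main()
{
    const MagicDiv md = make_magic(7);
    const uint32_t tests[] = {0u, 6u, 7u, 100u, 4294967295u};
    for(uint32_t n : tests)
        printf("%u / 7 = %u\n", n, magic_divide(n, md));
}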
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct DynamicUnMerge
struct UnMerge
{
static constexpr index_t NDimUp = UpLengths::Size();

@@ -1332,17 +1336,17 @@ struct DynamicUnMerge
using UpperIndex = MultiIndex<NDimUp>;

using UpLengthsScan =
decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies_v2{}, Number<1>{}));
decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{}));

UpLengths up_lengths_;
UpLengthsScan up_lengths_scan_;

__host__ __device__ constexpr DynamicUnMerge() = default;
__host__ __device__ constexpr UnMerge() = default;

__host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths)
__host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
: up_lengths_{up_lengths},
up_lengths_scan_{
container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})}
container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})}
{
}

@@ -1414,7 +1418,7 @@ struct DynamicUnMerge
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicUnMerge, ");
printf("UnMerge, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("up_lengths_scan_");
@@ -1424,13 +1428,13 @@ struct DynamicUnMerge
};

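// A worked instance of the index math implied by the reverse exclusive scan
// above (an illustration, not part of the commit): with up_lengths = (2, 3, 4),
// the scan with multiplies and init 1 gives up_lengths_scan_ = (12, 4, 1), so
// an upper index (i0, i1, i2) maps to the lower offset i0*12 + i1*4 + i2.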
template <typename LowerIndex>
struct DynamicFreeze
struct Freeze
{
LowerIndex low_idx_;

__host__ __device__ constexpr DynamicFreeze() = default;
__host__ __device__ constexpr Freeze() = default;

__host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
__host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}

__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

@@ -1483,22 +1487,22 @@ struct DynamicFreeze

__host__ __device__ void Print() const
{
printf("DynamicFreeze");
printf("Freeze");
printf("low_idx_ %d", index_t{low_idx_});
}
};

// Insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct DynamicInsert
struct Insert
{
using UpLengths = decltype(make_tuple(UpperLength{}));

UpLengths up_lengths_;

__host__ __device__ constexpr DynamicInsert() = default;
__host__ __device__ constexpr Insert() = default;

__host__ __device__ constexpr DynamicInsert(const UpperLength& up_length)
__host__ __device__ constexpr Insert(const UpperLength& up_length)
: up_lengths_{make_tuple(up_length)}
{
}
@@ -1550,13 +1554,13 @@ struct DynamicInsert

__host__ __device__ void Print() const
{
printf("DynamicInsert");
printf("Insert");
print_multi_index(up_lengths_);
}
};

template <typename VectorSize, typename UpLength>
struct DynamicVectorize
struct Vectorize
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
@@ -1566,10 +1570,10 @@ struct DynamicVectorize
UpLengths up_lengths_;
VectorSize vector_size_;

__host__ __device__ constexpr DynamicVectorize() = default;
__host__ __device__ constexpr Vectorize() = default;

__host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
const UpLength& up_length)
__host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
const UpLength& up_length)
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
{
}
@@ -1633,7 +1637,7 @@ struct DynamicVectorize
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicVectorize, ");
printf("Vectorize, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("}");
@@ -1641,7 +1645,7 @@ struct DynamicVectorize
};

template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct DynamicSlice
struct Slice
{
using LowerIndex = MultiIndex<1>;
using UpperIndex = MultiIndex<1>;
@@ -1652,11 +1656,11 @@ struct DynamicSlice
SliceBegin slice_begin_;
SliceEnd slice_end_;

__host__ __device__ constexpr DynamicSlice() = default;
__host__ __device__ constexpr Slice() = default;

__host__ __device__ constexpr DynamicSlice(const LowLength&,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
__host__ __device__ constexpr Slice(const LowLength&,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
@@ -1724,7 +1728,7 @@ struct DynamicSlice
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicSlice, ");
printf("Slice, ");
printf("up_lengths_");
print_multi_index(up_lengths_);
printf("slice_begin_ %d", index_t{slice_begin_});
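// Quick reference for the single-dimension transforms renamed above (an
// illustration, not part of the commit; semantics inferred from the
// constructors shown, so treat the details as assumptions):
//   Freeze(low_idx)         : pins the lower dimension at low_idx; exposes no upper dimension
//   Insert(up_length)       : exposes a dangling upper dimension of length up_length, no lower dimension
//   Vectorize(v, up_length) : groups v consecutive lower elements per upper index
//   Slice(low_len, b, e)    : exposes length e - b; upper index u maps to lower index u + b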
@@ -1,15 +1,15 @@
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform.hpp"
#include "multi_index_transform.hpp"

namespace ck {

template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{
return DynamicPassThrough<LowLength>{low_length};
return PassThrough<LowLength>{low_length};
}

template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
@@ -19,47 +19,46 @@ make_pad_transform(const LowLength& low_length,
const RightPad& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return DynamicPad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{
low_length, left_pad, right_pad};
return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}

template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false>
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_left_pad_transform(
const LowLength& low_length,
const LeftPad& left_pad,
const LeftPadLength& left_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return DynamicLeftPad<LowLength, LeftPad, SkipIsValidCheck>{low_length, left_pad};
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}

template <typename LowLength, typename RightPad, bool SkipIsValidCheck>
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
__host__ __device__ constexpr auto make_right_pad_transform(
const LowLength& low_length,
const RightPad& right_pad,
const RightPadLength& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return DynamicRightPad<LowLength, RightPad, SkipIsValidCheck>{low_length, right_pad};
return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
}

template <typename UpLengths,
typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients)
{
return DynamicEmbed<UpLengths, Coefficients>{up_lengths, coefficients};
return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}

template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
return Merge_v1_carry_check<LowLengths>{low_lengths};
#else
#if 1
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
return DynamicMerge_v2r2_magic_division<LowLengths>{low_lengths};
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
#endif
}
@@ -68,7 +67,7 @@ template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
return Merge_v2_magic_division<LowLengths>{low_lengths};
}

template <typename UpLengths, bool Use24BitIntegerCalculation = false>
@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths,
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
{
return DynamicUnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
return DynamicFreeze<LowerIndex>{low_idx};
return Freeze<LowerIndex>{low_idx};
}

template <typename LowLength, typename SliceBegin, typename SliceEnd>
@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
{
return DynamicSlice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}

template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}

} // namespace ck
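As an illustration of how these renamed helpers compose (a sketch under assumed shapes, not code from the commit; all names appear elsewhere in this diff):

// view a packed 4 x 8 tensor through a single merged dimension of length 32
const auto desc_4_8 =
    make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

const auto desc_32 = transform_tensor_descriptor(
    desc_4_8,
    make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
    make_tuple(Sequence<0, 1>{}),  // lower dimensions consumed by the merge
    make_tuple(Sequence<0>{}));    // single upper dimension produced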
@@ -2,8 +2,8 @@
#define CK_TENSOR_ADAPTOR_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -64,7 +64,7 @@ struct TensorAdaptor
Number<ndim_top_>{});

// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}

template <index_t IDim>
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}

template <typename X,
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));

@@ -1,16 +1,16 @@
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform.hpp"
#include "multi_index_transform.hpp"

namespace ck {

template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate;
struct TensorCoordinate;

template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator;
struct TensorCoordinateStep;

// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
@@ -21,7 +21,7 @@ template <typename Transforms,
typename UpperDimensionIdss,
typename VisibleDimensionIds,
typename ElementSpaceSize>
struct DynamicTensorDescriptor
struct TensorDescriptor
{
// TODO make these private
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
@@ -69,7 +69,7 @@ struct DynamicTensorDescriptor
Number<ndim_visible_>{});

// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}

template <index_t IDim>
@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor

using VisibleIndex = MultiIndex<ndim_visible_>;
using HiddenIndex = MultiIndex<ndim_hidden_>;
using Coordinate = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;

// may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

public:
__host__ __device__ constexpr DynamicTensorDescriptor() = default;
__host__ __device__ constexpr TensorDescriptor() = default;

__host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size)
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size)
: transforms_{transforms},
element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size}
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
{
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");

return make_dynamic_tensor_coordinate(*this, idx).GetOffset();
return make_tensor_coordinate(*this, idx).GetOffset();
}

// TODO make these private
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
__host__ __device__ void Print() const
{
printf("{");
printf("DynamicTensorDescriptor, ");
printf("TensorDescriptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
};

template <index_t NDimHidden, typename VisibleDimensionIds>
struct DynamicTensorCoordinate
struct TensorCoordinate
{
// TODO make these private
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
using VisibleIndex = MultiIndex<ndim_visible_>;

public:
__host__ __device__ constexpr DynamicTensorCoordinate() = default;
__host__ __device__ constexpr TensorCoordinate() = default;

__host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden)
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
{
}
@@ -252,16 +252,16 @@ struct DynamicTensorCoordinate
};

template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct DynamicTensorCoordinateIterator
struct TensorCoordinateStep
{
// TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>;

public:
__host__ __device__ constexpr DynamicTensorCoordinateIterator() = default;
__host__ __device__ constexpr TensorCoordinateStep() = default;

__host__ __device__ constexpr DynamicTensorCoordinateIterator(
const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{
}
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinateIterator

// TODO: How to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor, and it is put outside the scope where it is used
// (transform_dynamic_tensor_descriptor) because a template cannot be defined inside a function
// (transform_tensor_descriptor) because a template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
@@ -301,10 +301,10 @@ template <typename OldTensorDescriptor,
typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto
transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
{
// sanity check
{
@@ -376,17 +376,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,

const auto element_space_size = old_tensor_desc.GetElementSpaceSize();

return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size};
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size};
}

template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc,
const VisibleIndex& idx_visible)
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const VisibleIndex& idx_visible)
{
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent");
@@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
set_container_subset(idx_hidden, dims_low, idx_low);
});

return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}

// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const VisibleIndex& idx_diff_visible,
UpdateLowerIndexHack)
{
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent");
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});

return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
idx_diff_visible, do_transforms};
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
do_transforms};
}

template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto
make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const VisibleIndex& idx_diff_visible)
{
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();

return make_dynamic_tensor_coordinate_iterator(
return make_tensor_coordinate_step(
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
}

template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
__host__ __device__ constexpr void move_dynamic_tensor_coordinate(
const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
TensorCoord& coord,
const TensorCoordStep& coord_step)
{
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
@@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

// initialize visible index diff
set_container_subset(idx_diff_hidden,
TensorDesc::GetVisibleDimensionIds(),
coord_iterator.GetVisibleIndexDiff());
set_container_subset(
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());

// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
@@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
auto idx_hidden_pick_visible =
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());

idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
idx_hidden_pick_visible += coord_step.GetIndexDiff();

set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);

// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_iterator.do_transforms_[itran])
if(coord_step.do_transforms_[itran])
{
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
@@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(

MultiIndex<dims_low.Size()> idx_diff_low;

// HACK: control UpdateLowerIndex for DynamicMerge using hack
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
// HACK: control UpdateLowerIndex for Merge using hack
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);

tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});

@@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
}

template <typename TensorDesc>
using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate(
using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));

template <typename TensorDesc>
using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator(
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));

} // namespace ck
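A sketch of the renamed coordinate API in use (illustration only; desc is an assumed 2-D descriptor, and make_multi_index is assumed to come from the library's common header):

// materialize a coordinate, then advance it by a precomputed constant step
auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0));
const auto step = make_tensor_coordinate_step(desc, make_multi_index(0, 1));

move_tensor_coordinate(desc, coord, step);     // coord now addresses element (0, 1)
const index_t offset = coord.GetOffset();      // linear offset after the move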
@@ -1,9 +1,9 @@
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "multi_index_transform_helper.hpp"

namespace ck {

@@ -37,10 +37,9 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt

template <typename... Lengths,
typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides)
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides)
{
constexpr index_t N = sizeof...(Lengths);

@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif

return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}

// Lengths... can be:
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
// 2) Number<>, which is known at compile-time
template <typename... Lengths>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{
constexpr index_t N = sizeof...(Lengths);

@@ -101,19 +100,19 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)

constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});

return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}

template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
{
constexpr auto I1 = Number<1>{};

@@ -134,7 +133,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
else
{
return container_reduce(lengths,
math::multiplies_v2{},
math::multiplies{},
Number<stride_n_minus_2>{},
i + I1,
Number<N - 1>{},
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
},
Number<N>{});

return make_dynamic_naive_tensor_descriptor_v2(lengths, strides);
return make_naive_tensor_descriptor(lengths, strides);
}

} // namespace ck
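A concrete instance of the renamed constructors (a sketch, not code from the commit; run-time lengths shown, Number<> arguments work the same way):

// 2 x 3 row-major view: strides (3, 1), so element (i, j) sits at offset i*3 + j
const auto row_major = make_naive_tensor_descriptor(make_tuple(2, 3), make_tuple(3, 1));

// the packed variant derives the dense strides itself
const auto packed = make_naive_tensor_descriptor_packed(make_tuple(2, 3));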
@@ -3,7 +3,7 @@

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"

namespace ck {
@@ -22,24 +22,24 @@ namespace ck {
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AKMBlockDesc,
typename BKNBlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false>
template <
index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AKMBlockDesc,
typename BKNBlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
using AIndex = MultiIndex<3>;
@@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
static constexpr index_t N0 = N / N1;

__host__ __device__ static constexpr auto
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
{
const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
AKMBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
@@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
}

__host__ __device__ static constexpr auto
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
{
const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
BKNBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
@@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
__device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2

private:
// A[K, M0, M1]
static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

// B[K, N0, N1]
static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_k_m0_m1_block_desc_),
decltype(a_k_m0_m1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M11,
1>;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_k_m0_m1_block_desc_),
decltype(a_k_m0_m1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M11,
1>;

using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_k_n0_n1_block_desc_),
decltype(b_k_n0_n1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N11,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_k_n0_n1_block_desc_),
decltype(b_k_n0_n1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N11,
1>;

CIndex c_thread_origin_data_idx_;

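// The "ABBA" order the comment in this struct refers to, sketched for
// M0 = N0 = 2 (an assumed illustration of the reuse pattern, not code
// from the commit):
//   fma(C[0][0], A[0], B[0]);   // A0 B0
//   fma(C[0][1], A[0], B[1]);   // A0 B1 - reuse A[0] from registers
//   fma(C[1][1], A[1], B[1]);   // A1 B1 - reuse B[1] from registers
//   fma(C[1][0], A[1], B[0]);   // A1 B0 - revisit B[0]
// Each fma reuses one operand of its predecessor, so reads overlap with math.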
@@ -3,7 +3,7 @@

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"

namespace ck {
@@ -38,9 +38,9 @@ template <index_t BlockSize,
// BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using AIndex = MultiIndex<3>;
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
private:
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));

// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));

using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder

using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_),

@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;

static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));

static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));

using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;

__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
@@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! K dimension not consistent\n");

constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);

@@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
"wrong! inconsistent type");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto a_block_mtx = BlockMatrixA{};

@@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static_assert(WPerThread % WoPerThreadSubC == 0, "");

// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;

constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,

@@ -2,7 +2,7 @@
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"

namespace ck {
@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;

if constexpr(xdlops_gemm.IsKReduction)
{
@@ -73,7 +72,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;

if constexpr(xdlops_gemm.IsKReduction)
@@ -193,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1

private:
// A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

// B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;

using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;

AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
@@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;

if constexpr(xdlops_gemm.IsKReduction)
{
@@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;

if constexpr(xdlops_gemm.IsKReduction)
@@ -490,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline

private:
// A[K, M]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));

// B[K, N]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));

static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
1>;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
1>;

using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
1>;

AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;

|
||||
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
|
||||
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
@@ -33,16 +33,16 @@ template <index_t BlockSize,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseDynamicTensorSliceTransfer_v4
|
||||
struct BlockwiseTensorSliceTransfer_v4
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin)
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin)
|
||||
: threadwise_transfer_(
|
||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||
|
||||
@@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
|
||||
}
|
||||
}
|
||||
|
||||
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowIteratorHack>
|
||||
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& step,
|
||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(
|
||||
src_desc, step, src_move_slice_window_iterator_hack);
|
||||
src_desc, step, src_move_slice_window_step_hack);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,25 +143,25 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
@@ -1,18 +1,18 @@
|
||||
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
@@ -31,17 +31,16 @@ template <index_t BlockSize,
|
||||
typename DstVectorTensorContiguousDimOrder,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
||||
struct BlockwiseTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin)
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin)
|
||||
: threadwise_transfer_(
|
||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||
|
||||
@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
||||
}
|
||||
}
|
||||
|
||||
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowIteratorHack>
|
||||
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& step,
|
||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(
|
||||
src_desc, step, src_move_slice_window_iterator_hack);
|
||||
src_desc, step, src_move_slice_window_step_hack);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,23 +131,23 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorTensorLengths,
|
||||
DstVectorTensorLengths,
|
||||
SrcVectorTensorContiguousDimOrder,
|
||||
DstVectorTensorContiguousDimOrder,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorTensorLengths,
|
||||
DstVectorTensorLengths,
|
||||
SrcVectorTensorContiguousDimOrder,
|
||||
DstVectorTensorContiguousDimOrder,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
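
[Editor's note] make_cluster_descriptor (renamed from make_cluster_descriptor_v2 above) assigns each linear thread id a coordinate inside ThreadClusterLengths, permuted by ThreadClusterArrangeOrder. CK computes this at compile time from Sequences; the host-side sketch below shows only the row-major id-to-coordinate mapping and is illustrative, not the library implementation:

    #include <array>

    template <int N>
    std::array<int, N> linear_to_cluster_idx(int tid, const std::array<int, N>& lengths)
    {
        std::array<int, N> idx{};
        for(int d = N - 1; d >= 0; --d) // last dimension varies fastest
        {
            idx[d] = tid % lengths[d];
            tid /= lengths[d];
        }
        return idx;
    }

    // e.g. in a 4 x 8 cluster, thread 13 lands at coordinate (1, 5)
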
@@ -1,14 +1,14 @@
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"

namespace ck {

@@ -25,7 +25,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_contraction_dlops_v1r2(
kernel_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
@@ -84,12 +84,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -110,17 +110,15 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;

const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0),
@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;

const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0),
@@ -250,16 +248,16 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
constexpr auto BN = GN0 * GN11;

constexpr auto BM1 =
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies_v2{}, I1) *
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
BM1PerThreadBM11>{};
constexpr auto BN1 =
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies_v2{}, I1) *
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
BN1PerThreadBN11>{};

constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;

const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);

// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);

// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);

static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
"wrong!");

// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
make_multi_index(0, 0, 0, 0, 0));

// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
@@ -457,9 +453,8 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
@@ -475,9 +470,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
BGridMoveSliceWindowStepHacks{});

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
BGridMoveSliceWindowStepHacks{});

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
BGridMoveSliceWindowStepHacks{});

__syncthreads();

// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});

// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
// output: register to global memory
{
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());

ThreadwiseDynamicTensorSliceTransfer_v1r3<
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf,
CGridIteratorHacks{});
CGridStepHacks{});
}
}
};
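
[Editor's note] The main loop above is a classic LDS double buffer: while the blockwise GEMM consumes one tile, the blockwise copy fetches the next tile into the other buffer, with __syncthreads() separating the two roles. A host-side sketch of that schedule under stated assumptions (tiles modeled as ints; the sequential loop stands in for the barrier):

    #include <cstdio>

    int main()
    {
        const int num_k_blocks = 4;
        int lds[2] = {0, 0}; // stand-ins for the even/odd LDS tiles
        lds[0] = 0;          // preload: RunRead + RunWrite of tile 0
        for(int k = 1; k < num_k_blocks; ++k)
        {
            lds[k % 2] = k;                                     // load next tile
            std::printf("GEMM on tile %d\n", lds[(k - 1) % 2]); // compute on current tile
        }
        std::printf("GEMM on tile %d\n", lds[(num_k_blocks - 1) % 2]); // tail iteration
        return 0;
    }
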
@@ -1,14 +1,14 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"

namespace ck {

@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_dlops_v1r2(
kernel_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
@@ -68,28 +68,27 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast void CONSTANT* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m0_m1_grid_desc =
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
const auto b_k_n0_n1_grid_desc =
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));

constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -146,12 +145,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

// LDS allocation for A and B: be careful of alignment
@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1;

const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
a_k_m_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1;

const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
b_k_n_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;

const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
@@ -352,75 +351,75 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);

// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0));
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0));

// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0));
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0));

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
@@ -447,9 +446,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

constexpr auto c_m10_m11_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
@@ -465,9 +463,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};

// hack to control index calculation when moving the slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};

auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);

// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
b_k_n0_n1_global_move_slice_window_step_hack);

__syncthreads();

// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);

// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2

// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;

constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;

constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;

constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

ThreadwiseDynamicTensorSliceTransfer_v1r3<
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
CGridStepHacks{});
}
}
};
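
[Editor's note] Both kernel entry points above receive their grid descriptors as opaque `const void CONSTANT*` arguments; the old code recovered the concrete type with a plain C cast, the new code with cast_pointer_to_generic_address_space. The pattern, condensed into one helper as a sketch (the name load_grid_desc is illustrative; CONSTANT and the cast helper are the ones shown in this diff):

    template <typename GridDesc>
    __device__ GridDesc load_grid_desc(const void CONSTANT* p)
    {
        // address_space(4) pointer -> generic pointer -> concrete descriptor;
        // the descriptor's copy constructor cannot take an address_space(4) pointer
        return *reinterpret_cast<const GridDesc*>(
            cast_pointer_to_generic_address_space(p));
    }
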
|
||||
@@ -1,14 +1,14 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
|
||||
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
|
||||
#define CK_GRIDWISE_GEMM_V1R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -26,7 +26,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_dlops_v1r3(
|
||||
kernel_gemm_dlops_v1r3(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
@@ -68,28 +68,27 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_gemm_dlops_v1r3(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
// first cast void CONSTANT void* to void*
|
||||
// second cast void* to Desc*
|
||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||
const auto a_k0_m0_m1_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc);
|
||||
const auto b_k0_n0_n1_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc);
|
||||
const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
|
||||
const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
@@ -142,12 +141,12 @@ template <index_t BlockSize,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridIteratorHacks,
|
||||
typename BGridIteratorHacks,
|
||||
typename CGridIteratorHacks,
|
||||
typename AGridMoveSliceWindowIteratorHacks,
|
||||
typename BGridMoveSliceWindowIteratorHacks>
|
||||
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
|
||||
}
|
||||
@@ -231,13 +230,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
const auto M1 = Number<MPerBlockM1>{};
|
||||
const auto M0 = M / M1;
|
||||
|
||||
const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
a_k0_m_k1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
const auto a_k0_m0_m1_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k0_m_k1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return a_k0_m0_m1_k1_grid_desc;
|
||||
}
|
||||
@@ -251,13 +250,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
const auto N1 = Number<NPerBlockN1>{};
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
b_k0_n_k1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
const auto b_k0_n0_n1_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_k0_n_k1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return b_k0_n0_n1_k1_grid_desc;
|
||||
}
|
||||
@@ -275,16 +274,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
const auto N0 = N / N1;
|
||||
|
||||
constexpr auto M11 =
|
||||
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
|
||||
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
|
||||
M1PerThreadM111>{};
|
||||
constexpr auto N11 =
|
||||
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
|
||||
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto M10 = M1 / M11;
|
||||
constexpr auto N10 = N1 / N11;
|
||||
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||
@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, for blockwise GEMM
|
||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, for blockwise GEMM
|
||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
|
||||
@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
"wrong!");
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
||||
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
|
||||
@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
||||
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
|
||||
@@ -453,9 +452,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
@@ -471,9 +469,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
BGridMoveSliceWindowStepHacks{});

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
BGridMoveSliceWindowStepHacks{});

__syncthreads();

// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});

__syncthreads();

// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});

// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3

// output: register to global memory
{
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};

constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;

constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;

constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());

ThreadwiseDynamicTensorSliceTransfer_v1r3<
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
CGridStepHacks{});
}
}
};
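The v1r3 main loop above is a classic LDS double-buffer pipeline: while the blockwise GEMM consumes the "even" buffer, the next K-slice is fetched and written to the "odd" buffer, with __syncthreads() separating the phases, and a dedicated tail handles the final one or two iterations. A minimal sketch of the ping-pong pattern, with hypothetical load_tile()/compute_tile() helpers standing in for the blockwise copy and blockwise GEMM (they are not composable_kernel APIs):

    #define TILE_SIZE 256
    __device__ void load_tile(float* dst, int k_iter);   // hypothetical stand-in
    __device__ void compute_tile(const float* src);      // hypothetical stand-in

    __device__ void gemm_main_loop(int num_k_iters)
    {
        __shared__ float lds_even[TILE_SIZE];
        __shared__ float lds_odd[TILE_SIZE];

        load_tile(lds_even, /*k_iter=*/0); // preload the first tile

        for(int k = 0; k + 2 < num_k_iters; k += 2)
        {
            __syncthreads();
            load_tile(lds_odd, k + 1);  // fetch the next tile ...
            compute_tile(lds_even);     // ... while consuming the current one

            __syncthreads();
            load_tile(lds_even, k + 2); // buffers swap roles on odd iterations
            compute_tile(lds_odd);
        }
        // tail iterations handled separately, cf. HasDoubleTailKBlockLoop above
    }

The real code is more refined (RunRead stages global data into registers before RunWrite commits it to LDS), but the even/odd alternation is the same.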
@@ -1,12 +1,12 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
#ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_GEMM_V2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp"

namespace ck {
@@ -42,12 +42,12 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemmDlops_km_kn_mn_v3
typename AGlobalStepHacks,
typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);

// LDS allocation for A and B: be careful of alignment
@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3

// divide block work by [M, N]
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this forces the result into SGPR
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
@@ -134,23 +132,21 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);

constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));

// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
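The "packed" descriptor above lays dimensions out contiguously, while the "aligned" variant presumably rounds the row stride up to max_lds_align so every row of the LDS tile starts on an aligned address (the same rounding the integer_least_multiple allocation math does for buffer sizes later in this diff). A stride sketch under that assumption, not the library's implementation:

    #include <cstdio>

    // Hypothetical stride math for a 2-D [D0, D1] descriptor.
    constexpr int least_multiple(int x, int align) { return (x + align - 1) / align * align; }

    // packed:  strides = {D1, 1},                    element space = D0 * D1
    // aligned: strides = {least_multiple(D1, A), 1}, rows padded to alignment A
    int main()
    {
        constexpr int K = 130, A = 4;
        std::printf("packed row stride:  %d\n", K);                    // 130
        std::printf("aligned row stride: %d\n", least_multiple(K, A)); // 132
    }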
@@ -184,47 +180,46 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3

// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
make_multi_index(0, 0));
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
make_multi_index(0, 0));

constexpr auto b_e_n_ho_wo_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
@@ -232,44 +227,45 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf;

// initialize output thread tensor
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);

// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};

// hack to control index calculation when moving slice window for A and B matrix for
// threadwise copy
constexpr auto a_e_k_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
BGlobalMoveSliceWindowStepHacks{};

// double register buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;

// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);

b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks);
b_e_n_ho_wo_global_step_hacks);

a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
}
@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks);
b_e_n_ho_wo_global_step_hacks);

// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_iterator_hacks);
b_e_n_ho_wo_global_step_hacks);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_iterator_hacks);
b_e_n_ho_wo_global_step_hacks);

// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
@@ -351,23 +347,22 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};

const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;

ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
c_thread_buf,
c_k_n_ho_wo_global_desc,
c_global_buf,
c_k_n_ho_wo_global_tensor_iterator_hacks);
c_k_n_ho_wo_global_tensor_step_hacks);
}
}
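The IteratorHacks -> StepHacks rename running through this whole commit is purely a naming change: the parameter is still a compile-time tuple of per-transform hints consumed by make_tensor_coordinate_step, and a tuple of zero sequences means "no hint". The "no hacks" default, mirroring the code that appears later in this diff (ntransform is the descriptor's transform count in the real code), looks like:

    // One zero-sequence per descriptor transform, one tuple each for the
    // forward and backward steps; sketch restating the diff's own idiom.
    constexpr auto zeros = typename uniform_sequence_gen<ntransform, 0>::type{};

    constexpr auto no_step_hacks =
        make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),  // forward
                   generate_tuple([&](auto) { return zeros; }, Number<nDim>{})); // backward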
@@ -1,14 +1,14 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"

namespace ck {

@@ -24,13 +24,13 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -58,25 +58,25 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_block_cluster_adaptor =
*reinterpret_cast<const CBlockClusterAdaptor*>((const void*)p_c_block_cluster_adaptor);
const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));

__shared__ FloatAB p_shared_block[shared_block_size];
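This second kernel entry point receives the tensor descriptors by pointer in a CONSTANT-qualified buffer instead of by value. The old code laundered the address space through a bare (const void*) cast; the new code goes through cast_pointer_to_generic_address_space before the typed reinterpret_cast. A hedged sketch of the shape of such a helper, assuming an address-space-qualified pointer can be round-tripped through uintptr_t (the real implementation lives in composable_kernel's common headers and may differ):

    // Illustration only: strip the constant address space and hand back a
    // generic pointer the kernel can reinterpret_cast to the descriptor type.
    template <typename T>
    __device__ const T* cast_constant_to_generic(const void CONSTANT* p)
    {
        return reinterpret_cast<const T*>(reinterpret_cast<uintptr_t>(p));
    }

Going through a named helper rather than a C-style cast keeps the address-space conversion explicit and greppable.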
@@ -126,13 +126,13 @@ template <index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat>
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -148,12 +148,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

// LDS allocation for A and B: be careful of alignment
@@ -203,9 +203,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);

constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};

constexpr auto CLayout = xdlops_gemm.GetCLayout();
@@ -217,10 +214,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);

constexpr auto N0 = Number<CLayout.N1()>{};
constexpr auto N1 = Number<CLayout.N0()>{};

const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
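make_unmerge_transform splits one linear dimension into several, so a single M index becomes the nest (MRepeat, MWaves, M0, M1, M2). Assuming the conventional row-major split, the inverse mapping is a mixed-radix expansion; a host-side sketch of the index math (an illustration, not ck's implementation):

    #include <array>
    #include <cstdio>

    // Row-major "unmerge" of a linear index m into nested coordinates with
    // lengths L[0..N-1]: m = (((i0*L1 + i1)*L2 + i2)*L3 + i3)*L4 + i4, etc.
    template <std::size_t N>
    std::array<int, N> unmerge(int m, const std::array<int, N>& lengths)
    {
        std::array<int, N> idx{};
        for(int d = static_cast<int>(N) - 1; d >= 0; --d)
        {
            idx[d] = m % lengths[d]; // peel off the innermost length
            m /= lengths[d];
        }
        return idx;
    }

    int main()
    {
        // e.g. M = MRepeat * MWaves * M0 * M1 * M2 with illustrative lengths
        const std::array<int, 5> lens{2, 2, 4, 2, 4};
        const auto i = unmerge(37, lens);
        std::printf("%d %d %d %d %d\n", i[0], i[1], i[2], i[3], i[4]); // 0 1 0 1 1
    }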
@@ -269,11 +265,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
@@ -282,8 +273,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());

const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);

// divide block work by [M, N]
const auto block_work_idx =
@@ -301,67 +290,65 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3

// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k0_m_k1_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k0_m_k1_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));

// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k0_n_k1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k0_n_k1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
@@ -375,7 +362,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");

constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
@@ -384,7 +371,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
@@ -410,21 +397,19 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3

static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");

constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
constexpr auto c_mr_nr_blk_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()>
c_mr_nr_blk_desc.GetElementSpaceSize(),
true>
c_thread_buf;

// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
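A and B share a single __shared__ allocation; each block descriptor's element space is rounded up to max_lds_align so the B tile starts on an aligned boundary. A self-contained sketch of the same arithmetic:

    #include <cstddef>

    // Round x up to the next multiple of align (math::integer_least_multiple).
    constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
    {
        return (x + align - 1) / align * align;
    }

    // Carving one LDS block into an A region and an aligned B region:
    //   a_size = integer_least_multiple(a_elems, max_lds_align)
    //   p_a    = p_shared
    //   p_b    = p_shared + a_size
    static_assert(integer_least_multiple(130, 8) == 136, "B tile starts aligned");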
@@ -432,15 +417,13 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);

// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};

// hack to control index calculation when moving slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};

auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
@@ -449,10 +432,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3

// preload data into LDS
{
a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
@@ -465,18 +446,16 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_iterator_hack);
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_iterator_hack);
b_k0_n_k1_grid_move_slice_window_step_hack);

a_blockwise_copy.RunRead(
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);

block_sync_lds();

b_blockwise_copy.RunRead(
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
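Note the ordering inside the loop body above: the next A tile's global read is issued before block_sync_lds() and B's after it, so the outstanding global loads overlap the barrier and the MFMA work that follows. A schematic of one steady-state iteration, abbreviating the ck calls as comments (the RunWrite half of the pipeline presumably follows in code outside the hunks shown here):

    // One steady-state iteration of the xdlops main loop (schematic):
    //
    //   a_copy.MoveSrcSliceWindow(...);            // advance both windows along K
    //   b_copy.MoveSrcSliceWindow(...);
    //   a_copy.RunRead(...);                       // issue next A global loads early
    //   block_sync_lds();                          // wait for LDS tiles to be consumable
    //   b_copy.RunRead(...);                       // issue next B global loads
    //   blockwise_gemm.Run(a_lds, b_lds, c_regs);  // MFMA on the current tiles
    //   a_copy.RunWrite(...);                      // stage the next tiles into LDS
    //   b_copy.RunWrite(...);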
@@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t N1 = CLayout.N0();

constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
@@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{},
Number<1>{}));

StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;

static_for<0, MRepeat, 1>{}([&](auto mr_i) {
@@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};

constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);

ThreadwiseDynamicTensorSliceTransfer_v1r3<
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
@@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
c_m0_m1_m2_n_grid_tensor_step_hacks);
}
#else
{
@@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();

constexpr auto c_m0_m1_m2_n_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));

StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, BlkSize> c_blk_buf_;
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};

auto c_thread_copy =
ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(0,
0,
@@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
c_m0_m1_m2_n_grid_tensor_step_hacks);

return c_thread_idx_;
};
@@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
c_m0_m1_m2_n_grid_tensor_step_hacks);
};

auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
@@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
c_m0_m1_m2_n_grid_tensor_step_hacks);
};

auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
@@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
c_m0_m1_m2_n_grid_tensor_step_hacks);
};

auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
@@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
c_m0_m1_m2_n_grid_tensor_step_hacks);
};

static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
@@ -21,10 +21,10 @@ template <typename FloatA,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
@@ -97,10 +97,9 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));

amd_inner_product_dlop<FloatA, FloatB, FloatC>(
a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
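The old amd_inner_product_dlop free function becomes the generically named inner_product. For scalar operands the contract is a plain multiply-accumulate into c, with vector specializations expected to lower onto the AMD dot-product (DLops) instructions. A scalar reference sketch of the semantics, not ck's specialization set:

    // Reference semantics of inner_product for scalars: c += a * b.
    // The real header specializes on vector types so the accumulate maps
    // to hardware dot instructions where the target provides them.
    template <typename TA, typename TB, typename TC>
    __host__ __device__ void inner_product_ref(const TA& a, const TB& b, TC& c)
    {
        c += static_cast<TC>(a) * static_cast<TC>(b);
    }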
@@ -124,10 +123,10 @@ template <typename FloatA,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));

amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
inner_product<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));

@@ -19,9 +19,9 @@ template <typename FloatA,
typename CDesc,
index_t H,
index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
template <typename ABuffer,
@@ -57,8 +57,6 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);

@@ -1,9 +1,9 @@
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#ifndef CK_THREADWISE_TENSOR_SET_HPP
#define CK_THREADWISE_TENSOR_SET_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -11,12 +11,12 @@ namespace ck {
// 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time
// 4. use #-iterator
// 4. use #-step
template <typename Data,
typename Desc,
typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceSet_v1
{
static constexpr index_t nDim = SliceLengths::Size();

@@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1
constexpr auto origin_idx = to_multi_index(OriginIdx{});

static_ford<SliceLengths>{}([&](auto access_idx) {
constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);

constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
@@ -1,9 +1,9 @@
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -57,20 +57,20 @@ template <typename SrcData,
InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3
{
static constexpr index_t nDim = SliceLengths::Size();

using Index = MultiIndex<nDim>;

using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(
const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
: dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
const Index& dst_slice_origin_idx)
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
@@ -78,19 +78,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3

__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}

template <typename SrcSliceOriginIdx,
typename SrcBuffer,
typename DstBuffer,
typename DstIteratorHacks>
typename DstStepHacks>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks)
const DstStepHacks& dst_step_hacks)
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
@@ -127,31 +127,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);

// make forward iterators
const auto dst_forward_iterators = generate_tuple(
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step;
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});

return make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
return make_tensor_coordinate_step(
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
},
Number<nDim>{});

// make backward iterators
const auto dst_backward_iterators = generate_tuple(
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step;
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});

return make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
return make_tensor_coordinate_step(
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
},
Number<nDim>{});
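The paired forward/backward steps exist because the transfer walks the slice in a snake (boustrophedon) order: the inner dimension alternates direction on each pass, so the coordinate only ever moves by one step and index updates stay cheap. A host-side sketch of the visit order over a 2-D slice, illustrating the idea rather than ck's implementation:

    #include <cstdio>

    // Snake traversal of an H x W slice: even rows left-to-right, odd rows
    // right-to-left, so consecutive visits always differ by a single +/-1
    // step -- the reason the transfer above builds forward *and* backward
    // steps per dimension.
    int main()
    {
        constexpr int H = 3, W = 4;
        int y = 0, x = 0;
        bool forward = true;
        for(int visit = 0; visit < H * W; ++visit)
        {
            std::printf("(%d,%d) ", y, x);
            const bool row_done = forward ? (x == W - 1) : (x == 0);
            if(row_done) { ++y; forward = !forward; } // drop a row, flip direction
            else         { x += forward ? 1 : -1; }   // single-step move
        }
        std::printf("\n");
    }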
|
||||
|
||||
@@ -235,13 +235,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_dynamic_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_dynamic_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -250,10 +250,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_iterator =
|
||||
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,11 +268,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||
|
||||
constexpr auto dst_iterator_hacks =
|
||||
constexpr auto dst_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
|
||||
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
@@ -345,10 +345,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step =
|
||||
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -374,20 +373,20 @@ template <typename SrcData,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx)
|
||||
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
@@ -395,19 +394,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
|
||||
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer,
|
||||
typename DstBuffer,
|
||||
typename DstSliceOriginIdx,
|
||||
typename SrcIteratorHacks>
|
||||
typename SrcStepHacks>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
@@ -442,31 +441,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
||||
constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto src_forward_iterators = generate_tuple(
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_dynamic_tensor_coordinate_iterator(
|
||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto src_backward_iterators = generate_tuple(
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||
});
|
||||
|
||||
return make_dynamic_tensor_coordinate_iterator(
|
||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
@@ -548,13 +547,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
}
}
});
@@ -563,10 +562,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}

@@ -581,11 +580,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2

constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

constexpr auto src_iterator_hacks =
constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
}

__device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -658,10 +657,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}

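The MoveSrcSliceWindow hunk above folds the pending reset displacement into the window step whenever the coordinate was not reset after RunRead(), saving one coordinate update. A host-side sketch of that bookkeeping, with invented names (not CK's API):

#include <cassert>

struct Cursor
{
    int pos = 0;
};

// reset_step is the (negative) displacement back to the slice origin
int fused_window_step(bool reset_after_run, int window_step, int reset_step)
{
    // if the coordinate was already reset, move by the window step alone;
    // otherwise fuse the outstanding reset into the same move
    return reset_after_run ? window_step : window_step + reset_step;
}

int main()
{
    Cursor c{6};                              // read left the cursor at the end of a 6-wide slice
    c.pos += fused_window_step(false, 8, -6); // one fused move: -6 reset + 8 window
    assert(c.pos == 8);                       // same result as resetting first, then stepping
    return 0;
}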
private:
@@ -693,23 +691,23 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3
struct ThreadwiseTensorSliceTransfer_v3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;

using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
{
// TODO: fix this
static_assert(is_same<SrcData, DstData>::value,
@@ -718,18 +716,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}

__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}

template <typename SrcBuffer, typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks)
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -757,31 +754,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

// make forward iterators
const auto src_forward_iterators = generate_tuple(
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step;
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});

return make_dynamic_tensor_coordinate_iterator(
src_desc, forward_step, src_iterator_hacks[I0][i]);
return make_tensor_coordinate_step(
src_desc, forward_step_idx, src_step_hacks[I0][i]);
},
Number<nDim>{});

// make backward iterators
const auto src_backward_iterators = generate_tuple(
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step;
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});

return make_dynamic_tensor_coordinate_iterator(
src_desc, backward_step, src_iterator_hacks[I1][i]);
return make_tensor_coordinate_step(
src_desc, backward_step_idx, src_step_hacks[I1][i]);
},
Number<nDim>{});

@@ -862,13 +859,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
@@ -877,17 +874,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}

template <typename DstBuffer, typename DstIteratorHacks>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks)
template <typename DstBuffer, typename DstStepHacks>
__device__ void
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
{
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -915,35 +911,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

// make forward iterators
const auto dst_forward_iterators = generate_tuple(
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step;
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});

const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);

return forward_iterator;
return make_tensor_coordinate_step(
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
},
Number<nDim>{});

// make backward iterators
const auto dst_backward_iterators = generate_tuple(
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step;
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});

const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);

return backward_iterator;
return make_tensor_coordinate_step(
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
},
Number<nDim>{});

@@ -1026,13 +1018,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
@@ -1041,10 +1033,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}

@@ -1055,11 +1047,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

constexpr auto src_iterator_hacks =
constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

RunRead(src_desc, src_buf, src_iterator_hacks);
RunRead(src_desc, src_buf, src_step_hacks);
}

template <typename DstBuffer>
@@ -1069,11 +1061,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

constexpr auto dst_iterator_hacks =
constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
RunWrite(dst_desc, dst_buf, dst_step_hacks);
}

__device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -1206,18 +1198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}

// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowIteratorHack>
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
@@ -1225,10 +1216,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
@@ -1240,19 +1231,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}

private:
static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));

static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;

SrcCoord src_coord_;
DstCoord dst_coord_;
@@ -1264,37 +1254,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator
// 5. use #-step
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation
// 3. vector access on src
template <
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4
{
static constexpr index_t nDim = SliceLengths::Size();

using Index = MultiIndex<nDim>;

using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
@@ -1390,13 +1379,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

constexpr auto src_ref_to_data_disp_coord_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

auto src_data_coord = src_ref_coord_;

move_dynamic_tensor_coordinate(
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

@@ -1435,10 +1423,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
{
constexpr auto src_desc = SrcDesc{};

const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
src_desc, to_multi_index(src_slice_move_step_idx));
const auto src_slice_move_step_iter =
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));

move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
}

private:

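The renamed transfer structs above all follow the same protocol: stage a slice of the source tensor into a small register-resident buffer with RunRead, then flush it with RunWrite, moving the slice window between iterations. A self-contained sketch of that protocol in plain C++, with invented types standing in for CK's descriptors and StaticBuffer:

#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N>
struct SliceTransfer
{
    std::size_t src_origin = 0;
    std::size_t dst_origin = 0;
    std::array<float, N> staging{}; // plays the role of the VGPR StaticBuffer

    void RunRead(const float* src) // copy the slice window from src into staging
    {
        for(std::size_t i = 0; i < N; ++i)
            staging[i] = src[src_origin + i];
    }
    void RunWrite(float* dst) // flush staging into dst
    {
        for(std::size_t i = 0; i < N; ++i)
            dst[dst_origin + i] = staging[i];
    }
    void MoveSrcSliceWindow(std::size_t step) { src_origin += step; }
};

int main()
{
    float src[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    float dst[4] = {};
    SliceTransfer<4> t;
    t.RunRead(src);
    t.MoveSrcSliceWindow(4); // the next RunRead would stage {4, 5, 6, 7}
    t.RunWrite(dst);
    assert(dst[3] == 3.0f);
    return 0;
}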
@@ -1,9 +1,9 @@
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

namespace ck {

@@ -30,7 +30,7 @@ template <typename SliceLengths,
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseDynamicTensorSliceTransfer_v3r1
struct ThreadwiseTensorSliceTransfer_v3r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -38,18 +38,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;

using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
{
// TODO: fix this
static_assert(is_same<SrcData, DstData>::value,
@@ -64,18 +64,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}

__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}

template <typename SrcBuffer, typename SrcIteratorHacks>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks)
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -92,13 +91,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
container_reverse_exclusive_scan(
container_reorder_given_new2old(src_vector_tensor_lengths,
SrcVectorTensorContiguousDimOrder{}),
math::multiplies_v2{},
math::multiplies{},
I1),
SrcVectorTensorContiguousDimOrder{});

constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));
constexpr auto src_vector_desc =
make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));

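The strides fed into the naive descriptor above come from a reverse exclusive scan of the (reordered) lengths under multiplication, seeded with 1: for packed lengths {L0, L1, L2} the strides are {L1*L2, L2, 1}. A small worked sketch in plain C++:

#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N>
std::array<int, N> reverse_exclusive_scan_mul(const std::array<int, N>& lengths, int init)
{
    std::array<int, N> strides{};
    int running = init; // the "I1" seed in the hunk above
    for(std::size_t i = N; i-- > 0;)
    {
        strides[i] = running; // exclusive: the current length is not yet folded in
        running *= lengths[i];
    }
    return strides;
}

int main()
{
    const std::array<int, 3> lengths{2, 3, 4};
    const auto strides = reverse_exclusive_scan_mul(lengths, 1);
    assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1);
    return 0;
}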
// access order and lengths
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
@@ -108,31 +107,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

// make forward iterators
const auto src_forward_iterators = generate_tuple(
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step;
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
});

return make_dynamic_tensor_coordinate_iterator(
src_desc, forward_step, src_iterator_hacks[I0][i]);
return make_tensor_coordinate_step(
src_desc, forward_step_idx, src_step_hacks[I0][i]);
},
Number<nDim>{});

// make backward iterators
const auto src_backward_iterators = generate_tuple(
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step;
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
});

return make_dynamic_tensor_coordinate_iterator(
src_desc, backward_step, src_iterator_hacks[I1][i]);
return make_tensor_coordinate_step(
src_desc, backward_step_idx, src_step_hacks[I1][i]);
},
Number<nDim>{});

@@ -219,13 +218,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
@@ -234,17 +233,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());

move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}

template <typename DstBuffer, typename DstIteratorHacks>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks)
template <typename DstBuffer, typename DstStepHacks>
__device__ void
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
{
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
@@ -261,13 +259,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
container_reverse_exclusive_scan(
container_reorder_given_new2old(dst_vector_tensor_lengths,
DstVectorTensorContiguousDimOrder{}),
math::multiplies_v2{},
math::multiplies{},
I1),
DstVectorTensorContiguousDimOrder{});

constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
sequence_to_tuple_of_number(dst_vector_tensor_lengths),
sequence_to_tuple_of_number(dst_vector_tensor_strides));
constexpr auto dst_vector_desc =
make_naive_tensor_descriptor(sequence_to_tuple_of_number(dst_vector_tensor_lengths),
sequence_to_tuple_of_number(dst_vector_tensor_strides));

// dst access order and lengths
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
@@ -277,35 +275,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

// make forward iterators
const auto dst_forward_iterators = generate_tuple(
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step;
Index forward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
});

const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, forward_step, dst_iterator_hacks[I0][i]);

return forward_iterator;
return make_tensor_coordinate_step(
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
},
Number<nDim>{});

// make backward iterators
const auto dst_backward_iterators = generate_tuple(
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step;
Index backward_step_idx;

static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
});

const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
dst_desc, backward_step, dst_iterator_hacks[I1][i]);

return backward_iterator;
return make_tensor_coordinate_step(
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
},
Number<nDim>{});

@@ -394,13 +388,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
@@ -409,10 +403,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_iterator =
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());

move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}

@@ -423,11 +417,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

constexpr auto src_iterator_hacks =
constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

RunRead(src_desc, src_buf, src_iterator_hacks);
RunRead(src_desc, src_buf, src_step_hacks);
}

template <typename DstBuffer>
@@ -437,11 +431,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

constexpr auto dst_iterator_hacks =
constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
RunWrite(dst_desc, dst_buf, dst_step_hacks);
}

__device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -564,18 +558,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);

move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}

// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowIteratorHack>
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
@@ -583,10 +576,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);

move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
@@ -598,19 +591,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();

// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);

move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}

private:
static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));

static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;

SrcCoord src_coord_;
DstCoord dst_coord_;
@@ -622,25 +614,24 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
// 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator
// 5. use #-step
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation
// 3. vector access on src
template <
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
typename SrcVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4r1
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
typename SrcVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -649,12 +640,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1

using Index = MultiIndex<nDim>;

using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));

using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));

__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
@@ -708,13 +699,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
container_reverse_exclusive_scan(
container_reorder_given_new2old(src_vector_tensor_lengths,
SrcVectorTensorContiguousDimOrder{}),
math::multiplies_v2{},
math::multiplies{},
I1),
SrcVectorTensorContiguousDimOrder{});

constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));
constexpr auto src_vector_desc =
make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));

// access order and lengths
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
@@ -734,13 +725,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

constexpr auto src_ref_to_data_disp_coord_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);

auto src_data_coord = src_ref_coord_;

move_dynamic_tensor_coordinate(
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);

vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;

@@ -775,10 +765,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
{
constexpr auto src_desc = SrcDesc{};

const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
src_desc, to_multi_index(src_slice_move_step_idx));
const auto src_slice_move_step_iter =
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));

move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
}

private:

@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
}
@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
}
@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);

return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}

44
composable_kernel/include/utility/amd_address_space.hpp
Normal file
@@ -0,0 +1,44 @@
#ifndef CK_AMD_ADDRESS_SPACE_HPP
#define CK_AMD_ADDRESS_SPACE_HPP

#include "config.hpp"
#include "c_style_pointer_cast.hpp"

// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space

namespace ck {

enum AddressSpaceEnum_t
{
Generic,
Global,
Lds,
Sgpr,
Vgpr,
};

template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only a C-style pointer cast seems to compile
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only a C-style pointer cast seems to compile
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CONSTANT*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

} // namespace ck
#endif

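A hedged usage sketch of the new file (assumes a HIP/Clang AMDGPU toolchain and CK's CONSTANT macro, which expands to an address_space(4) qualifier): kernel arguments can be carried as "Constant" pointers and cast back to the Generic address space before being dereferenced on the device. The kernel name and shapes below are invented for illustration.

#include <hip/hip_runtime.h>
#include "amd_address_space.hpp"

__global__ void scale_kernel(const float CONSTANT* p_in_const, float* p_out, int n)
{
    // Constant-address-space pointer -> generic pointer usable by normal loads
    const float* p_in = ck::cast_pointer_to_generic_address_space(p_in_const);

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p_out[i] = 2.0f * p_in[i];
}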
@@ -1,34 +1,34 @@
#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP
#define CK_AMD_BUFFER_ADDRESSING_V2_HPP
#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
#define CK_AMD_BUFFER_ADDRESSING_HPP

#include "data_type.hpp"

namespace ck {

template <typename T>
union BufferResource_v2
union BufferResource
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t data;
int32x4_t content;
StaticallyIndexedArray<T*, 2> address;
StaticallyIndexedArray<int32_t, 4> range;
StaticallyIndexedArray<int32_t, 4> config;
};

template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{
BufferResource_v2<T> wave_buffer_resource;
BufferResource<T> wave_buffer_resource;

// wavewise base address (64 bit)
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit)
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
// wavewise setting (32 bit)
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

return wave_buffer_resource.data;
return wave_buffer_resource.content;
}

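The union above overlays a 16-byte buffer resource descriptor: dwords 0-1 hold the base address, dword 2 the range in bytes, dword 3 the configuration word. A host-compilable sketch of the same overlay in plain C++, with invented names and a placeholder config value; like the original, it type-puns through a union:

#include <cassert>
#include <cstdint>

union BufferResourceSketch
{
    uint32_t dword[4]; // raw 128-bit view handed to buffer instructions
    struct
    {
        const void* base;    // dwords 0-1: wavewise base address
        uint32_t byte_range; // dword 2: extent of the buffer in bytes
        uint32_t config;     // dword 3: data format / addressing settings
    } fields;
};

int main()
{
    static float buf[256];
    BufferResourceSketch r{};
    r.fields.base       = buf;
    r.fields.byte_range = sizeof(buf); // element_space_size * sizeof(T)
    r.fields.config     = 0x00020000u; // placeholder for CK_BUFFER_RESOURCE_3RD_DWORD
    assert(r.dword[2] == sizeof(buf));
    return 0;
}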
// load
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");

template <typename T, index_t N>
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
}

template <typename T, index_t N>
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type

// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave to be a wavewise pointer.
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_v2(const T* p_src_wave,
index_t src_thread_data_offset,
bool src_thread_data_valid,
index_t src_element_space)
amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space);
make_wave_buffer_resource(p_src_wave, src_element_space_size);

index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;

using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;

return amd_buffer_load_impl_v2<scalar_t, vector_size>(
return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);

return src_thread_data_valid ? tmp : vector_t(0);
return src_thread_element_valid ? tmp : vector_t(0);
#endif
}

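The experimental trick guarded above avoids a per-lane branch: invalid lanes have 0x7fffffff added to their byte offset, pushing the address far past the descriptor's range so the hardware range check returns zero instead. A plain-C++ sketch of that predication logic with a simulated range-checked load:

#include <cassert>
#include <cstdint>

// stand-in for a raw buffer load: hardware returns 0 when offset >= range
float simulated_buffer_load(const float* base, uint32_t byte_range, uint32_t byte_offset)
{
    if(byte_offset >= byte_range)
        return 0.0f; // hardware-level OOB behavior of buffer_load
    return base[byte_offset / sizeof(float)];
}

float load_with_trick(const float* base, uint32_t byte_range, uint32_t byte_offset, bool valid)
{
    const uint32_t shift = valid ? 0u : 0x7fffffffu; // branchless invalidation
    return simulated_buffer_load(base, byte_range, shift + byte_offset);
}

int main()
{
    const float buf[4] = {1, 2, 3, 4};
    assert(load_with_trick(buf, sizeof(buf), 8, true) == 3.0f);  // valid lane: element 2
    assert(load_with_trick(buf, sizeof(buf), 8, false) == 0.0f); // invalid lane: zero-fill
    return 0;
}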
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size);

index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);

using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;

constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);

return src_thread_element_valid ? tmp : vector_t(customized_value);
}

// buffer_store requires:
// 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_data_offset,
const bool dst_thread_data_valid,
const index_t dst_element_space)
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space);
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;

amd_buffer_store_impl_v2<scalar_t, vector_size>(
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_data_valid)
if(dst_thread_element_valid)
{
amd_buffer_store_impl_v2<scalar_t, vector_size>(
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif

@@ -1,188 +0,0 @@
|
||||
#ifndef CK_AMD_DLOP_HPP
|
||||
#define CK_AMD_DLOP_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename TA, typename TB, typename TC>
|
||||
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
|
||||
{
|
||||
#if CK_USE_AMD_DLOP_INLINE_ASM
|
||||
asm volatile("\n \
|
||||
v_fmac_f32 %0, %1, %2 \n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(a), "v"(b), "0"(c));
|
||||
#else
|
||||
c += a * b;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
                           vector_type<float, 2>{b}.AsType<float>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
                           vector_type<float, 2>{b}.AsType<float>()[I1],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
                           vector_type<float, 4>{b}.AsType<float>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
                           vector_type<float, 4>{b}.AsType<float>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
                           vector_type<float, 4>{b}.AsType<float>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
                           vector_type<float, 4>{b}.AsType<float>()[I3],
                           c);
}

#if CK_USE_AMD_DLOP
template <>
__device__ void
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
            v_dot2_f32_f16 %0, %1, %2, %0\n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
}

template <>
__device__ void
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
                           vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
                           vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
                           c);
}

template <>
__device__ void amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a,
                                                                    const int8x4_t& b,
                                                                    int32_t& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0\n \
            "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}

template <>
__device__ void amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a,
                                                                    const int8x8_t& b,
                                                                    int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                           vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                           vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                           c);
}

template <>
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
                                                                      const int8x16_t& b,
                                                                      int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                           c);
}
#endif // CK_USE_AMD_DLOP

} // namespace ck
#endif
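Note: the vector specializations above simply peel `a` and `b` into the 2-lane (half) or 4-lane (int8) fragments that map onto a single v_dot2_f32_f16 or v_dot4_i32_i8 instruction, all accumulating into the same scalar `c`. A minimal sketch of a caller, not part of this commit (the function and pointer names are hypothetical):

// hypothetical caller: reduce k8 fragments of 8 halves each into one float
__device__ float reduce_dot_half8(const ck::half8_t* p_a, const ck::half8_t* p_b, ck::index_t k8)
{
    float c = 0;

    for(ck::index_t i = 0; i < k8; ++i)
    {
        // each call decomposes into four v_dot2_f32_f16 when CK_USE_AMD_DLOP is enabled
        ck::amd_inner_product_dlop(p_a[i], p_b[i], c);
    }

    return c;
}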
@@ -2,6 +2,9 @@
#define CK_AMD_INLINE_ASM_HPP

#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"

// TODO: deprecate all amd_assembly_outer_product_xxx

namespace ck {

@@ -53,9 +56,9 @@ __device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
    // TODO remove pointer casting
    const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
    const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);

    // do dot2 two times
    asm volatile("\n \
@@ -114,11 +117,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
                                               float& c3)
{
    // TODO remove pointer casting
    const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
    const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
    const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
    const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
    const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
    const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);

    // do dot2 two times
    asm volatile("\n \
@@ -160,11 +163,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
{

    // TODO remove pointer casting
    const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
    const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
    const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
    const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
    const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
    const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
    const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
    const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
    const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
    const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
@@ -184,11 +187,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
                                               float& c3)
{
    // TODO remove pointer casting
    const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
    const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
    const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
    const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
    const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
    const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
    const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
    const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
    const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
    const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);

    amd_assembly_outer_product_1x4(
        p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
22
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
@@ -0,0 +1,22 @@
#ifndef CK_C_STYLE_POINTER_CAST_HPP
#define CK_C_STYLE_POINTER_CAST_HPP

#include "type.hpp"
#include "enable_if.hpp"

namespace ck {

template <typename PY,
          typename PX,
          typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
    return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}

} // namespace ck
#endif
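Note: c_style_pointer_cast exists so the single deliberate old-style cast lives behind push/pop pragmas and a NOLINT, keeping every call site clean under -Weverything and clang-tidy. A minimal usage sketch, not part of this commit (names are illustrative):

__device__ void split_half4_example(ck::half4_t a)
{
    // view one 4-lane half vector as two 2-lane vectors, as the
    // amd_assembly_outer_product_* routines above now do
    const ck::half2_t* p_a_half2 = ck::c_style_pointer_cast<const ck::half2_t*>(&a);

    ck::half2_t lo = p_a_half2[0]; // lanes 0-1
    ck::half2_t hi = p_a_half2[1]; // lanes 2-3
    (void)lo;
    (void)hi;
}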
@@ -7,13 +7,14 @@
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
#include "multi_index.hpp"
#include "data_type_enum.hpp"
#include "data_type.hpp"
#include "data_type_helper.hpp"
#include "data_type_enum.hpp"
#include "data_type_enum_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp"
#include "math.hpp"
#include "number.hpp"
@@ -23,21 +24,21 @@
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "type.hpp"
#include "utility.hpp"
#include "magic_division.hpp"
#include "amd_buffer_addressing_v2.hpp"
#include "utility.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_address_space.hpp"
#include "amd_buffer_addressing.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"

#include "inner_product.hpp"

// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif

#if CK_USE_AMD_DLOP
#include "amd_dlop.hpp"
#endif

#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#endif
@@ -7,19 +7,14 @@
#endif
#include "bfloat16_dev.hpp"

// address space for kernel parameter
// "Constant" address space for kernel parameter
#define CONSTANT __attribute__((address_space(4)))

// GPU target
// should enable one and only one GPU target
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
      defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
#error Need to define a single GPU target
#endif

// HIP version
#ifndef CK_HIP_VERSION_FLAT
#define CK_HIP_VERSION_FLAT 0
#error Need to define (only) one GPU target
#endif

// launch bounds
@@ -38,6 +33,16 @@
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif

// FMA instruction
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
#define CK_USE_AMD_V_MAC_F32
#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \
    defined(CK_AMD_GPU_GFX1030)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#endif

// multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0

@@ -46,13 +51,9 @@
#define CK_USE_AMD_INLINE_ASM 1
#endif

// AMD DLOPS
#ifndef CK_USE_AMD_DLOP
#define CK_USE_AMD_DLOP 1
#endif

#ifndef CK_USE_AMD_DLOP_INLINE_ASM
#define CK_USE_AMD_DLOP_INLINE_ASM 1
// AMD inner product (DLOP)
#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
#endif

// AMD buffer addressing
@@ -99,8 +100,8 @@
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// thread-invariant, otherwise it's a bug
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
#endif

// workaround for compiler crash when compiling recursive lambda
@@ -120,15 +121,6 @@

namespace ck {

enum AddressSpaceEnum_t
{
    Generic,
    Global,
    Lds,
    Sgpr,
    Vgpr
};

enum InMemoryDataOperationEnum_t
{
    Set,
@@ -3,8 +3,7 @@

namespace ck {

// this enumerate should be synchronized with include/miopen.h
typedef enum
enum DataTypeEnum_t
{
    Half = 0,
    Float = 1,
@@ -14,7 +13,7 @@ typedef enum
    BFloat16 = 5,
    Double = 6,
    Unknown = 100,
} DataTypeEnum_t;
};

} // namespace ck
#endif
@@ -1,5 +1,5 @@
#ifndef CK_DATA_TYPE_HELPER_HPP
#define CK_DATA_TYPE_HELPER_HPP
#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
#define CK_DATA_TYPE_ENUM_HELPER_HPP

#include "data_type.hpp"
#include "data_type_enum.hpp"
@@ -1,38 +1,49 @@
#ifndef CK_DYNAMIC_BUFFER_HPP
#define CK_DYNAMIC_BUFFER_HPP
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"

namespace ck {

#include "amd_buffer_addressing_v2.hpp"

template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
template <AddressSpaceEnum_t BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_ = T{0};

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
              typename enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
@@ -44,29 +55,50 @@ struct DynamicBuffer
        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                p_data_, i, is_valid_offset, element_space_size_);
        bool constexpr use_amd_buffer_addressing = true;
#else
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_return_zero<
                    remove_cv_t<remove_reference_t<T>>,
                    t_per_x>(p_data_, i, is_valid_element, element_space_size_);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<
                    remove_cv_t<remove_reference_t<T>>,
                    t_per_x>(
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
        }
        else
        {
            return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
            }
            else
            {
                return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
                                        : X{invalid_element_value_};
            }
        }
    }

    template <typename X,
              typename std::enable_if<
              typename enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
@@ -78,26 +110,26 @@ struct DynamicBuffer
        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X need to be multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_offset, element_space_size_);
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
#else
            if(is_valid_offset)
            if(is_valid_element)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
            }
#endif
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
        {
            if(is_valid_offset)
            if(is_valid_element)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
                *reinterpret_cast<X*>(&p_data_[i]) = x;
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#else
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
                // inefficient
@@ -128,24 +160,24 @@ struct DynamicBuffer
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int8_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int8_t*>(&x);
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int16_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int16_t*>(&x);
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32_t*>(&x);
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                          int8x4_t>::value &&
@@ -153,8 +185,8 @@ struct DynamicBuffer
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32_t*>(&x);
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                          int8x8_t>::value &&
@@ -162,8 +194,8 @@ struct DynamicBuffer
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32x2_t*>(&x);
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                          int8x16_t>::value &&
@@ -171,22 +203,22 @@ struct DynamicBuffer
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32x4_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32x4_t*>(&x);
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
            }
            else
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
            }
#endif
            }
        }
        else
        {
            if(is_valid_offset)
            if(is_valid_element)
            {
                *reinterpret_cast<X*>(&p_data_[i]) = x;
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
            }
        }
    }
@@ -196,12 +228,18 @@ struct DynamicBuffer
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
          typename T,
          typename ElementSpaceSize>
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
}

} // namespace ck
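Note: the two make_dynamic_buffer overloads pick the out-of-bounds policy at compile time: the old behavior (return X{0}) maps to InvalidElementUseNumericalZeroValue = true, while the new overload installs a custom sentinel. A hedged sketch of the new form, not part of this commit (pointer and extent are placeholders):

__device__ void dynamic_buffer_example(float* p_global, ck::index_t n)
{
    // invalid reads now yield -1.0f instead of 0.0f
    auto buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum_t::Global>(p_global, n, -1.0f);

    // is_valid_element guards the access; an invalid read returns the sentinel
    float v = buf.Get<float>(0, /*is_valid_element=*/true);
    buf.Set<float>(0, true, v + 1.0f);
}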
13
composable_kernel/include/utility/enable_if.hpp
Normal file
@@ -0,0 +1,13 @@
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP

namespace ck {

template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;

template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;

} // namespace ck
#endif
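Note: the alias lets device headers spell SFINAE constraints as ck::enable_if instead of std::enable_if, which is the spelling every signature in this diff migrates to. The recurring pattern, shown on a hypothetical overload (assumes <type_traits> is reachable in the translation unit):

// only participates in overload resolution for integral T (illustrative)
template <typename T, typename ck::enable_if<std::is_integral<T>::value, bool>::type = false>
__host__ __device__ T twice(T x)
{
    return x + x;
}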
207
composable_kernel/include/utility/inner_product.hpp
Normal file
@@ -0,0 +1,207 @@
#ifndef CK_INNER_PRODUCT_HPP
#define CK_INNER_PRODUCT_HPP

#include "data_type.hpp"

namespace ck {

template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);

template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
    asm volatile("\n \
            v_mac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
    asm volatile("\n \
            v_fmac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c += a * b;
#endif
}

template <>
__device__ void
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
                  vector_type<float, 2>{b}.AsType<float>()[I0],
                  c);

    inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
                  vector_type<float, 2>{b}.AsType<float>()[I1],
                  c);
}

template <>
__device__ void
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
                  vector_type<float, 4>{b}.AsType<float>()[I0],
                  c);

    inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
                  vector_type<float, 4>{b}.AsType<float>()[I1],
                  c);

    inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
                  vector_type<float, 4>{b}.AsType<float>()[I2],
                  c);

    inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
                  vector_type<float, 4>{b}.AsType<float>()[I3],
                  c);
}

template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
    asm volatile("\n \
            v_dot2_f32_f16 %0, %1, %2, %0\n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
#else
    const auto convert = type_convert<int32_t>{};

    const vector_type<half_t, 2> a_vector{a};
    const vector_type<half_t, 2> b_vector{b};

    static_for<0, 2, 1>{}([&](auto i) {
        c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
    });
#endif
}

template <>
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
                  vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
                  c);

    inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
                  vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
                  c);
}

template <>
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
                  c);

    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
                  c);

    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
                  c);

    inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
                  vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
                  c);
}

template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0\n \
            "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
#else
    const auto convert = type_convert<int32_t>{};

    const vector_type<int8_t, 4> a_vector{a};
    const vector_type<int8_t, 4> b_vector{b};

    static_for<0, 4, 1>{}([&](auto i) {
        c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
    });
#endif
}

template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                  vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                  c);

    inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                  vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                  c);
}

template <>
__device__ void
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                  c);

    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                  c);

    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                  c);

    inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                  vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                  c);
}

} // namespace ck
#endif
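Note: inner_product is the target-agnostic entry point: each specialization either emits a hardware dot instruction or falls back to a scalar multiply-accumulate loop, so callers never branch on the GPU target. A minimal illustrative caller, not part of this commit:

// quantized dot product: 16 int8 lanes reduced into one int32 accumulator
__device__ int32_t dot16_i8_example(const ck::int8x16_t& a, const ck::int8x16_t& b)
{
    int32_t c = 0;
    // decomposes into four int8x4_t fragments, each a v_dot4_i32_i8 when available
    ck::inner_product(a, b, c);
    return c;
}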
@@ -5,6 +5,7 @@
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "enable_if.hpp"

namespace ck {
namespace math {
@@ -27,13 +28,7 @@ struct minus
    __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
};

template <typename T>
struct multiplies
{
    __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
};

struct multiplies_v2
{
    template <typename A, typename B>
    __host__ __device__ constexpr auto operator()(const A& a, const B& b) const
@@ -184,9 +179,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
    return Number<r>{};
}

template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
    return gcd(x, gcd(ys...));
@@ -199,9 +192,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
    return (x * y) / gcd(x, y);
}

template <typename X,
          typename... Ys,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
{
    return lcm(x, lcm(ys...));
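Note: the variadic overloads fold to the right, i.e. gcd(x, ys...) == gcd(x, gcd(ys...)), and likewise for lcm, so Number<> arguments reduce entirely at compile time. A worked illustration, not part of this commit:

// folds as gcd(12, gcd(18, 24)) = gcd(12, 6) = 6, all at compile time
static_assert(ck::math::gcd(ck::Number<12>{}, ck::Number<18>{}, ck::Number<24>{}).value == 6, "");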
@@ -11,59 +11,11 @@ namespace ck {
template <typename T>
__host__ __device__ void print_array(const char* s, T a)
{
    using data_type = decltype(a.At(Number<0>{}));
    constexpr index_t nsize = a.Size();

#if 0
    if constexpr(is_same<data_type, uint32_t>{})
    {
        printf("%s size %u, {", s, nsize);
        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
        printf("}\n");
    }
    else if constexpr(is_same<data_type, int32_t>{})
    {
        printf("%s size %d, {", s, nsize);
        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
        printf("}\n");
    }
    else if constexpr(is_same<data_type, bool>{})
    {
        printf("%s size %d, {", s, nsize);
        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
        printf("}\n");
    }
#else
    printf("%s size %d, {", s, nsize);
    static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
    printf("}\n");
#endif
}

template <typename T>
__host__ __device__ void print_array_v2(const char* s, T a)
{
    using data_type = decltype(a.At(Number<0>{}));
    constexpr index_t nsize = a.Size();

#if 0
    if constexpr(is_same<data_type, uint32_t>{})
    {
        printf("%s size %u, {", s, nsize);
        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
        printf("}\n");
    }
    else if constexpr(is_same<data_type, int32_t>{})
    {
        printf("%s size %d, {", s, nsize);
        static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
        printf("}\n");
    }
#else
    printf("%s size %d, {", s, nsize);
    static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
    printf("}\n");
#endif
}

} // namespace ck
@@ -685,8 +685,6 @@ __host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
    constexpr auto seq_x = Sequence<Xs...>{};

    return Sequence<(Y - Xs)...>{};
}
@@ -5,30 +5,66 @@

namespace ck {

template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
template <AddressSpaceEnum_t BufferAddressSpace,
          typename T,
          index_t N,
          bool InvalidElementUseNumericalZeroValue>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    T invalid_element_value_ = T{0};

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ constexpr StaticBuffer(T invalid_element_value)
        : base{}, invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    template <index_t I>
    __host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
    {
        if constexpr(InvalidElementUseNumericalZeroValue)
        {
            return is_valid_element ? At(i) : T{0};
        }
        else
        {
            return is_valid_element ? At(i) : invalid_element_value_;
        }
    }

    template <index_t I>
    __host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
    {
        if(is_valid_element)
        {
            At(i) = x;
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
          typename T,
          index_t N>
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<BufferAddressSpace, T, N>{};
    return StaticBuffer<BufferAddressSpace, T, N, true>{};
}

template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
{
    return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
}

} // namespace ck
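Note: StaticBuffer mirrors the DynamicBuffer policy split for register-resident arrays. A hedged sketch of the new customized-value form, not part of this commit:

__device__ float static_buffer_example()
{
    // 4 floats held in registers; invalid reads yield -1.0f rather than 0.0f
    auto buf = ck::make_static_buffer<ck::AddressSpaceEnum_t::Vgpr, float>(ck::Number<4>{}, -1.0f);

    buf.Set(ck::Number<0>{}, true, 3.0f);

    return buf.Get(ck::Number<0>{}, true) + buf.Get(ck::Number<1>{}, false); // 3.0f + (-1.0f)
}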
@@ -4,6 +4,7 @@
#include "integral_constant.hpp"
#include "sequence.hpp"
#include "type.hpp"
#include "enable_if.hpp"

namespace ck {

@@ -20,10 +21,9 @@ struct TupleElement
{
    __host__ __device__ constexpr TupleElement() = default;

    template <
        typename T,
        typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
                                bool>::type = false>
    template <typename T,
              typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
    {
    }
@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{
    __host__ __device__ constexpr TupleImpl() = default;

    template <
        typename Y,
        typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
                                    !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
                                bool>::type = false>
    template <typename Y,
              typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
                                     !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr TupleImpl(Y&& y)
        : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
    {
    }

    template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
    template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
    __host__ __device__ constexpr TupleImpl(Ys&&... ys)
        : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
    {
@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
    __host__ __device__ constexpr Tuple() = default;

    template <typename Y,
              typename std::enable_if<
                  sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
                  bool>::type = false>
              typename enable_if<sizeof...(Xs) == 1 &&
                                     !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
    {
    }

    template <typename... Ys,
              typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
                                      bool>::type = false>
              typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
                  false>
    __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
    {
    }
@@ -2,6 +2,7 @@
#define CK_TYPE_HPP

#include "integral_constant.hpp"
#include "enable_if.hpp"

namespace ck {

@@ -22,10 +23,7 @@ template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
{
    return static_cast<typename std::remove_reference<T>::type&&>(t);
}
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;

template <typename T>
struct is_known_at_compile_time;
@@ -42,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
    static constexpr bool value = true;
};

template <typename Y,
          typename X,
          typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x)
{
    union AsType
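Note: as_type is a size-checked bit cast through a union; it is how the dot-product paths above feed int8x4_t operands to __builtin_amdgcn_sdot4 as plain int32_t. For example (illustrative, not part of this commit):

// reinterpret the bit pattern of a float as a 32-bit integer; 1.0f -> 0x3f800000
__device__ int32_t float_bits_example(float x)
{
    return ck::as_type<int32_t>(x);
}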
@@ -0,0 +1,370 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"

using namespace ck;

constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);

using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
constexpr index_t KPerThread = CK_PARAM_KPerThread;
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;

using ABlockTransferThreadSliceLengths_K_M0_M1 =
    Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
using ABlockTransferThreadClusterLengths_K_M0_M1 =
    Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
using ABlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;

constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
    CK_PARAM_ABlockTransferDstScalarPerVector_M1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);

using BBlockTransferThreadSliceLengths_K_N0_N1 =
    Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
using BBlockTransferThreadClusterLengths_K_N0_N1 =
    Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
using BBlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;

constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
    CK_PARAM_BBlockTransferDstScalarPerVector_N1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);

using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;

constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);

extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
    int n,
    int c,
    int hi,
    int wi,
    int k,
    int y,
    int x,
    int convStrideH,
    int convStrideW,
    int convDilationY,
    int convDilationX,
    int leftPadH,
    int leftPadW,
    int rightPadH,
    int rightPadW,
    void* p_a_k_m0_m1_grid_desc,
    void* p_b_k_n0_n1_grid_desc,
    void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
    void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
    const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;

    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));

    const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        make_tuple(convStrideH, convStrideW),
        make_tuple(convDilationY, convDilationX),
        make_tuple(leftPadH, leftPadW),
        make_tuple(rightPadH, rightPadW));

    const auto a_k_m_grid_desc = descs[I0];
    const auto b_k_n_grid_desc = descs[I1];
    const auto c_m_n_grid_desc = descs[I2];

    using AKMGridDesc = decltype(a_k_m_grid_desc);
    using BKNGridDesc = decltype(b_k_n_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{})));

    using BGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));

    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
        GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
                                        FloatAB,
                                        FloatAcc,
                                        FloatC,
                                        InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
                                        AKMGridDesc,
                                        BKNGridDesc,
                                        CMNGridDesc,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        M1PerThread,
                                        N1PerThread,
                                        KPerThread,
                                        M1N1ThreadClusterM10,
                                        M1N1ThreadClusterN10,
                                        M1N1ThreadClusterM11,
                                        M1N1ThreadClusterN11,
                                        ABlockTransferThreadSliceLengths_K_M0_M1,
                                        ABlockTransferThreadClusterLengths_K_M0_M1,
                                        ABlockTransferThreadClusterArrangeOrder,
                                        ABlockTransferSrcAccessOrder,
                                        ABlockTransferSrcVectorDim,
                                        ABlockTransferSrcScalarPerVector,
                                        ABlockTransferDstScalarPerVector_M1,
                                        AThreadTransferSrcResetCoordinateAfterRun,
                                        BBlockTransferThreadSliceLengths_K_N0_N1,
                                        BBlockTransferThreadClusterLengths_K_N0_N1,
                                        BBlockTransferThreadClusterArrangeOrder,
                                        BBlockTransferSrcAccessOrder,
                                        BBlockTransferSrcVectorDim,
                                        BBlockTransferSrcScalarPerVector,
                                        BBlockTransferDstScalarPerVector_N1,
                                        BThreadTransferSrcResetCoordinateAfterRun,
                                        CThreadTransferSrcDstAccessOrder,
                                        CThreadTransferSrcDstVectorDim,
                                        CThreadTransferDstScalarPerVector,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks>;

    auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
    auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
    auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    if(hipThreadIdx_x == 0)
    {
        *static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
        *static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
        *static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
            p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
        *static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
            p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
    };
};
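Note: the output-size formula above, ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1, reduces for the fixed shape used by the kernel below (hi = 28, pads = 1, dilation = 1, y = 3, stride = 1) to (28 + 1 + 1 - 2 - 1) / 1 + 1 = 28, i.e. a same-size convolution:

// worked example of the ho/wo formula for the hard-coded 28x28, 3x3, pad-1, stride-1 case
static_assert((28 + 1 + 1 - 1 * (3 - 1) - 1) / 1 + 1 == 28, "same-size convolution");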

extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const void CONSTANT* p_a_k_m0_m1_grid_desc,
            const void CONSTANT* p_b_k_n0_n1_grid_desc,
            const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
            const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    constexpr auto in_n_c_hi_wi_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
    constexpr auto wei_k_c_y_x_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
    constexpr auto out_n_k_ho_wo_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));

    constexpr auto descs =
        transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                        in_n_c_hi_wi_desc,
                                                                        out_n_k_ho_wo_desc,
                                                                        make_tuple(1, 1),
                                                                        make_tuple(1, 1),
                                                                        make_tuple(1, 1),
                                                                        make_tuple(1, 1));

    constexpr auto a_k_m_grid_desc = descs[I0];
    constexpr auto b_k_n_grid_desc = descs[I1];
    constexpr auto c_m_n_grid_desc = descs[I2];

    using AKMGridDesc = decltype(a_k_m_grid_desc);
    using BKNGridDesc = decltype(b_k_n_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{})));

    using BGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));

    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
        GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
                                        FloatAB,
                                        FloatAcc,
                                        FloatC,
                                        InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
                                        AKMGridDesc,
                                        BKNGridDesc,
                                        CMNGridDesc,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        M1PerThread,
                                        N1PerThread,
                                        KPerThread,
                                        M1N1ThreadClusterM10,
                                        M1N1ThreadClusterN10,
                                        M1N1ThreadClusterM11,
                                        M1N1ThreadClusterN11,
                                        ABlockTransferThreadSliceLengths_K_M0_M1,
                                        ABlockTransferThreadClusterLengths_K_M0_M1,
                                        ABlockTransferThreadClusterArrangeOrder,
                                        ABlockTransferSrcAccessOrder,
                                        ABlockTransferSrcVectorDim,
                                        ABlockTransferSrcScalarPerVector,
                                        ABlockTransferDstScalarPerVector_M1,
                                        AThreadTransferSrcResetCoordinateAfterRun,
                                        BBlockTransferThreadSliceLengths_K_N0_N1,
                                        BBlockTransferThreadClusterLengths_K_N0_N1,
                                        BBlockTransferThreadClusterArrangeOrder,
                                        BBlockTransferSrcAccessOrder,
                                        BBlockTransferSrcVectorDim,
                                        BBlockTransferSrcScalarPerVector,
                                        BBlockTransferDstScalarPerVector_N1,
                                        BThreadTransferSrcResetCoordinateAfterRun,
                                        CThreadTransferSrcDstAccessOrder,
                                        CThreadTransferSrcDstVectorDim,
                                        CThreadTransferDstScalarPerVector,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks>;

    constexpr auto a_k_m0_m1_grid_desc_tmp =
        GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    constexpr auto b_k_n0_n1_grid_desc_tmp =
        GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
    constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
    constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);

    const auto a_k_m0_m1_grid_desc =
        *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
    const auto b_k_n0_n1_grid_desc =
        *reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k_m0_m1_grid_desc,
                      b_k_n0_n1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
};
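Note: the kernel above is the consumer side of the two-step protocol: the _prepare kernel serializes the runtime tensor descriptors into device buffers, and this kernel reads them back through CONSTANT pointers, using compile-time descriptors only to recover the types. A quick sanity check on the hard-coded shape (illustrative, not part of this commit):

// element count of the packed 256x256x28x28 NCHW descriptor used above
static_assert(256 * 256 * 28 * 28 == 51380224, "packed NCHW element count");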
@@ -0,0 +1,358 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
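    // (Standard convolution output-size arithmetic: e.g. hi = 28, pads = 1 + 1,
    // y = 3, dilation = 1, stride = 1 gives ho = (28 + 2 - 2 - 1) / 1 + 1 = 28.)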

    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));

    const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        make_tuple(convStrideH, convStrideW),
        make_tuple(convDilationY, convDilationX),
        make_tuple(leftPadH, leftPadW),
        make_tuple(rightPadH, rightPadW),
        Number<K1>{});

    const auto a_k0_m_k1_grid_desc = descs[I0];
    const auto b_k0_n_k1_grid_desc = descs[I1];
    const auto c_m_n_grid_desc = descs[I2];

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using AGridStepHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using BGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                InMemoryDataOperationEnum_t::Set,
                                                AK0MK1GridDesc,
                                                BK0NK1GridDesc,
                                                CMNGridDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerWave,
                                                NPerWave,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGridStepHacks,
                                                BGridStepHacks,
                                                CGridStepHacks,
                                                AGridMoveSliceWindowStepHacks,
                                                BGridMoveSliceWindowStepHacks,
                                                false>;

    auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

    auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
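
    // The prepare kernel is presumably launched with a single thread block; only
    // thread 0 serializes the descriptor objects into the caller-provided buffers,
    // so there is no write race.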
    if(hipThreadIdx_x == 0)
    {
        *static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
            a_k0_m_k1_grid_desc;
        *static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
            b_k0_n_k1_grid_desc;
        *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
            c_m0_m1_m2_n_grid_desc;
        *static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
            p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
    }
};

extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const void CONSTANT* p_a_k0_m_k1_grid_desc,
        const void CONSTANT* p_b_k0_n_k1_grid_desc,
        const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
        const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    constexpr auto in_n_c_hi_wi_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
    constexpr auto wei_k_c_y_x_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
    constexpr auto out_n_k_ho_wo_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
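    // (These fixed 256x256x28x28 / 256x256x3x3 shapes are compile-time placeholders:
    // they only pin down the descriptor types via decltype below, while the real
    // runtime descriptors are loaded from the CONSTANT pointers further down.)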

    constexpr auto descs =
        transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                          in_n_c_hi_wi_desc,
                                                                          out_n_k_ho_wo_desc,
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          Number<K1>{});

    constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
    constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
    constexpr auto c_m_n_grid_desc = descs[I2];

    using AGridStepHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using BGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                InMemoryDataOperationEnum_t::Set,
                                                AK0MK1GridDesc,
                                                BK0NK1GridDesc,
                                                CMNGridDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerWave,
                                                NPerWave,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGridStepHacks,
                                                BGridStepHacks,
                                                CGridStepHacks,
                                                AGridMoveSliceWindowStepHacks,
                                                BGridMoveSliceWindowStepHacks,
                                                false>;

    constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
        GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
    constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);

    const auto a_k0_m_k1_grid_desc =
        *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
    const auto b_k0_n_k1_grid_desc =
        *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
    const auto c_m0_m1_m2_n_grid_desc =
        *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
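    // (The intermediate (const void*) cast appears to drop the CONSTANT address-space
    // qualifier so that reinterpret_cast can materialize the typed descriptors.)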

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m_k1_grid_desc,
                      b_k0_n_k1_grid_desc,
                      c_m0_m1_m2_n_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor);
};
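A minimal host-side sketch of how this prepare/run kernel pair is presumably driven
(the buffer sizes, grid size, stream handling, and the grid_size variable below are
illustrative assumptions, not part of this commit):

    // Hypothetical launcher: the *_prepare kernel serializes each descriptor into a
    // small device buffer, which the main kernel then reads back through its
    // CONSTANT pointer arguments.
    void *p_a_desc, *p_b_desc, *p_c_desc, *p_c_adaptor;
    hipMalloc(&p_a_desc, 4096); // 4 KiB assumed as a generous upper bound per descriptor
    hipMalloc(&p_b_desc, 4096);
    hipMalloc(&p_c_desc, 4096);
    hipMalloc(&p_c_adaptor, 4096);

    // One single-thread workgroup suffices: only thread 0 of the prepare kernel writes.
    hipLaunchKernelGGL(convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare,
                       dim3(1), dim3(1), 0, 0,
                       n, c, hi, wi, k, y, x,
                       convStrideH, convStrideW, convDilationY, convDilationX,
                       leftPadH, leftPadW, rightPadH, rightPadW,
                       p_a_desc, p_b_desc, p_c_desc, p_c_adaptor);

    hipLaunchKernelGGL(convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw,
                       dim3(grid_size), dim3(BlockSize), 0, 0,
                       p_a_grid, p_b_grid, p_c_grid,
                       p_a_desc, p_b_desc, p_c_desc, p_c_adaptor);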
@@ -0,0 +1,357 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"

using namespace ck;

constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);

using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;

constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;

using ABlockTransferThreadSliceLengths_K0_M_K1 =
    Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
    Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;

constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
    CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);

using BBlockTransferThreadSliceLengths_K0_N_K1 =
    Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
    Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;

constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
    CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);

using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;

extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
    int n,
    int hi,
    int wi,
    int c,
    int k,
    int y,
    int x,
    int convStrideH,
    int convStrideW,
    int convDilationY,
    int convDilationX,
    int leftPadH,
    int leftPadW,
    int rightPadH,
    int rightPadW,
    void* p_a_k0_m_k1_grid_desc,
    void* p_b_k0_n_k1_grid_desc,
    void* p_c_m0_m1_m2_n_grid_desc,
    void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
    const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;

    const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c));
    const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));

    const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
        in_n_hi_wi_c_desc,
        wei_k_y_x_c_desc,
        out_n_ho_wo_k_desc,
        make_tuple(convStrideH, convStrideW),
        make_tuple(convDilationY, convDilationX),
        make_tuple(leftPadH, leftPadW),
        make_tuple(rightPadH, rightPadW),
        Number<K1>{});

    const auto a_k0_m_k1_grid_desc = descs[I0];
    const auto b_k0_n_k1_grid_desc = descs[I1];
    const auto c_m_n_grid_desc = descs[I2];

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using BGridStepHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using AGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
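    // (Relative to the NCHW variant above, the long and short hack sequences swap
    // between A and B: in this NHWC layout the image tensor appears to feed GEMM A,
    // so the multi-transform hints now apply to the A-side descriptor.)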

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                InMemoryDataOperationEnum_t::Set,
                                                AK0MK1GridDesc,
                                                BK0NK1GridDesc,
                                                CMNGridDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerWave,
                                                NPerWave,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGridStepHacks,
                                                BGridStepHacks,
                                                CGridStepHacks,
                                                AGridMoveSliceWindowStepHacks,
                                                BGridMoveSliceWindowStepHacks,
                                                false>;

    auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

    auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    if(hipThreadIdx_x == 0)
    {
        *static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
            a_k0_m_k1_grid_desc;
        *static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
            b_k0_n_k1_grid_desc;
        *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
            c_m0_m1_m2_n_grid_desc;
        *static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
            p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
    }
};

extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const void CONSTANT* p_a_k0_m_k1_grid_desc,
        const void CONSTANT* p_b_k0_n_k1_grid_desc,
        const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
        const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    constexpr auto in_n_hi_wi_c_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
    constexpr auto wei_k_y_x_c_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256));
    constexpr auto out_n_ho_wo_k_desc =
        make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));

    constexpr auto descs =
        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
                                                                          wei_k_y_x_c_desc,
                                                                          out_n_ho_wo_k_desc,
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          Number<K1>{});

    constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
    constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
    constexpr auto c_m_n_grid_desc = descs[I2];

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using BGridStepHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using AGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 1, 0, 0>{}),
                                               make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 0, 0, 0>{},
                                                          Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
    using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                InMemoryDataOperationEnum_t::Set,
                                                AK0MK1GridDesc,
                                                BK0NK1GridDesc,
                                                CMNGridDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerWave,
                                                NPerWave,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGridStepHacks,
                                                BGridStepHacks,
                                                CGridStepHacks,
                                                AGridMoveSliceWindowStepHacks,
                                                BGridMoveSliceWindowStepHacks,
                                                false>;
    constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
        GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
    constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);

    const auto a_k0_m_k1_grid_desc =
        *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
    const auto b_k0_n_k1_grid_desc =
        *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
    const auto c_m0_m1_m2_n_grid_desc =
        *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m_k1_grid_desc,
                      b_k0_n_k1_grid_desc,
                      c_m0_m1_m2_n_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor);
};
@@ -1,7 +1,7 @@
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"

using namespace ck;
@@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBloc
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
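// (These two flags appear to pick compile-time specializations of the GEMM K loop:
// whether a main unrolled K-block loop runs at all, and whether a two-step tail is
// needed to drain the double-buffered pipeline.)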

extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
                                                                            index_t C,
                                                                            index_t Hi,
                                                                            index_t Wi,
                                                                            index_t K,
                                                                            index_t Y,
                                                                            index_t X,
                                                                            index_t ConvStrideH,
                                                                            index_t ConvStrideW,
                                                                            index_t ConvDilationH,
                                                                            index_t ConvDilationW,
                                                                            index_t InLeftPadH,
                                                                            index_t InLeftPadW,
                                                                            index_t InRightPadH,
                                                                            index_t InRightPadW,
                                                                            void* p_desc_tuple)
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
                                                                    int C_,
                                                                    int Hi_,
                                                                    int Wi_,
                                                                    int K_,
                                                                    int Y_,
                                                                    int X_,
                                                                    int ConvStrideH_,
                                                                    int ConvStrideW_,
                                                                    int ConvDilationH_,
                                                                    int ConvDilationW_,
                                                                    int InLeftPadH_,
                                                                    int InLeftPadW_,
                                                                    int InRightPadH_,
                                                                    int InRightPadW_,
                                                                    void* p_desc_tuple)
{
    index_t N = static_cast<index_t>(N_);
    index_t C = static_cast<index_t>(C_);
    index_t Hi = static_cast<index_t>(Hi_);
    index_t Wi = static_cast<index_t>(Wi_);
    index_t K = static_cast<index_t>(K_);
    index_t Y = static_cast<index_t>(Y_);
    index_t X = static_cast<index_t>(X_);
    index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
    index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
    index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
    index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
    index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
    index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
    index_t InRightPadH = static_cast<index_t>(InRightPadH_);
    index_t InRightPadW = static_cast<index_t>(InRightPadW_);
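    // (The switch from index_t to plain int parameters presumably keeps the exported
    // C symbol's ABI fixed regardless of how index_t is configured; the casts above
    // restore the library-internal type.)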

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
@@ -88,12 +104,9 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
    const index_t Wo =
        (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi));
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X));
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo))
    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));

    const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
@@ -114,7 +127,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
    using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
    using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);

    using AGridIteratorHacks =
    using AGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
@@ -126,7 +139,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{},    // 3-: GM11
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1

    using BGridIteratorHacks = decltype(make_tuple(
    using BGridStepHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
@@ -138,7 +151,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},    // 3-: GN11
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1

    using CGridIteratorHacks = decltype(make_tuple(
    using CGridStepHacks = decltype(make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
@@ -154,13 +167,13 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},    // 4-: BN0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;

    using BGridMoveSliceWindowIteratorHacks =
    using BGridMoveSliceWindowStepHacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;

    using GridwiseContraction =
        GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
        GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
            BlockSize,
            FloatAB,
            FloatAcc,
@@ -194,11 +207,11 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            AGridIteratorHacks,
            BGridIteratorHacks,
            CGridIteratorHacks,
            AGridMoveSliceWindowIteratorHacks,
            BGridMoveSliceWindowIteratorHacks>;
            AGridStepHacks,
            BGridStepHacks,
            CGridStepHacks,
            AGridMoveSliceWindowStepHacks,
            BGridMoveSliceWindowStepHacks>;

    if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
    {
@@ -220,7 +233,7 @@ extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
    convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
@@ -232,11 +245,11 @@ extern "C" __global__ void
    constexpr auto I3 = Number<3>{};

    constexpr auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
    constexpr auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
    constexpr auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
        make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));

    constexpr auto descs =
        transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
@@ -257,7 +270,7 @@ extern "C" __global__ void
    using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
    using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);

    using AGridIteratorHacks =
    using AGridStepHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
@@ -269,7 +282,7 @@ extern "C" __global__ void
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{},    // 3-: GM11
                                       Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1

    using BGridIteratorHacks = decltype(make_tuple(
    using BGridStepHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
@@ -281,7 +294,7 @@ extern "C" __global__ void
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},    // 3-: GN11
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1

    using CGridIteratorHacks = decltype(make_tuple(
    using CGridStepHacks = decltype(make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
@@ -297,13 +310,13 @@ extern "C" __global__ void
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},    // 4-: BN0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
    using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;

    using BGridMoveSliceWindowIteratorHacks =
    using BGridMoveSliceWindowStepHacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;

    using GridwiseContraction =
        GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
        GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
            BlockSize,
            FloatAB,
            FloatAcc,
@@ -337,11 +350,11 @@ extern "C" __global__ void
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            AGridIteratorHacks,
            BGridIteratorHacks,
            CGridIteratorHacks,
            AGridMoveSliceWindowIteratorHacks,
            BGridMoveSliceWindowIteratorHacks>;
            AGridStepHacks,
            BGridStepHacks,
            CGridStepHacks,
            AGridMoveSliceWindowStepHacks,
            BGridMoveSliceWindowStepHacks>;

    using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
        decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
@@ -1,374 +0,0 @@
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"

using namespace ck;

constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);

using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
constexpr index_t KPerThread = CK_PARAM_KPerThread;
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;

using ABlockTransferThreadSliceLengths_K_M0_M1 =
    Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
using ABlockTransferThreadClusterLengths_K_M0_M1 =
    Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
using ABlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;

constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
    CK_PARAM_ABlockTransferDstScalarPerVector_M1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);

using BBlockTransferThreadSliceLengths_K_N0_N1 =
    Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
using BBlockTransferThreadClusterLengths_K_N0_N1 =
    Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
using BBlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;

constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
    CK_PARAM_BBlockTransferDstScalarPerVector_N1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);

using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;

constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);

extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
    int n,
    int c,
    int hi,
    int wi,
    int k,
    int y,
    int x,
    int convStrideH,
    int convStrideW,
    int convDilationY,
    int convDilationX,
    int leftPadH,
    int leftPadW,
    int rightPadH,
    int rightPadW,
    void* p_a_k_m0_m1_grid_desc,
    void* p_b_k_n0_n1_grid_desc,
    void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
    void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
    const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));

    const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        make_tuple(convStrideH, convStrideW),
        make_tuple(convDilationY, convDilationX),
        make_tuple(leftPadH, leftPadW),
        make_tuple(rightPadH, rightPadW));

    const auto a_k_m_grid_desc = descs[I0];
    const auto b_k_n_grid_desc = descs[I1];
    const auto c_m_n_grid_desc = descs[I2];

    using AKMGridDesc = decltype(a_k_m_grid_desc);
    using BKNGridDesc = decltype(b_k_n_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{})));

    using BGridIteratorHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));

    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
        GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
                                               FloatAB,
                                               FloatAcc,
                                               FloatC,
                                               InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
                                               AKMGridDesc,
                                               BKNGridDesc,
                                               CMNGridDesc,
                                               MPerBlock,
                                               NPerBlock,
                                               KPerBlock,
                                               M1PerThread,
                                               N1PerThread,
                                               KPerThread,
                                               M1N1ThreadClusterM10,
                                               M1N1ThreadClusterN10,
                                               M1N1ThreadClusterM11,
                                               M1N1ThreadClusterN11,
                                               ABlockTransferThreadSliceLengths_K_M0_M1,
                                               ABlockTransferThreadClusterLengths_K_M0_M1,
                                               ABlockTransferThreadClusterArrangeOrder,
                                               ABlockTransferSrcAccessOrder,
                                               ABlockTransferSrcVectorDim,
                                               ABlockTransferSrcScalarPerVector,
                                               ABlockTransferDstScalarPerVector_M1,
                                               AThreadTransferSrcResetCoordinateAfterRun,
                                               BBlockTransferThreadSliceLengths_K_N0_N1,
                                               BBlockTransferThreadClusterLengths_K_N0_N1,
                                               BBlockTransferThreadClusterArrangeOrder,
                                               BBlockTransferSrcAccessOrder,
                                               BBlockTransferSrcVectorDim,
                                               BBlockTransferSrcScalarPerVector,
                                               BBlockTransferDstScalarPerVector_N1,
                                               BThreadTransferSrcResetCoordinateAfterRun,
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
                                               AGridIteratorHacks,
                                               BGridIteratorHacks,
                                               CGridIteratorHacks,
                                               AGridMoveSliceWindowIteratorHacks,
                                               BGridMoveSliceWindowIteratorHacks>;

    auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
    auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
    auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    if(hipThreadIdx_x == 0)
    {
        *static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
        *static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
        *static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
            p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
        *static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
            p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
    };
};

extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        const void CONSTANT* p_a_k_m0_m1_grid_desc,
        const void CONSTANT* p_b_k_n0_n1_grid_desc,
        const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
        const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    constexpr auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
    constexpr auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
    constexpr auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));

    constexpr auto descs =
        transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                        in_n_c_hi_wi_desc,
                                                                        out_n_k_ho_wo_desc,
                                                                        make_tuple(1, 1),
                                                                        make_tuple(1, 1),
                                                                        make_tuple(1, 1),
                                                                        make_tuple(1, 1));

    constexpr auto a_k_m_grid_desc = descs[I0];
    constexpr auto b_k_n_grid_desc = descs[I1];
    constexpr auto c_m_n_grid_desc = descs[I2];

    using AKMGridDesc = decltype(a_k_m_grid_desc);
    using BKNGridDesc = decltype(b_k_n_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{})));

    using BGridIteratorHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));

    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
        GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
                                               FloatAB,
                                               FloatAcc,
                                               FloatC,
                                               InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
                                               AKMGridDesc,
                                               BKNGridDesc,
                                               CMNGridDesc,
                                               MPerBlock,
                                               NPerBlock,
                                               KPerBlock,
                                               M1PerThread,
                                               N1PerThread,
                                               KPerThread,
                                               M1N1ThreadClusterM10,
                                               M1N1ThreadClusterN10,
                                               M1N1ThreadClusterM11,
                                               M1N1ThreadClusterN11,
                                               ABlockTransferThreadSliceLengths_K_M0_M1,
                                               ABlockTransferThreadClusterLengths_K_M0_M1,
                                               ABlockTransferThreadClusterArrangeOrder,
                                               ABlockTransferSrcAccessOrder,
                                               ABlockTransferSrcVectorDim,
                                               ABlockTransferSrcScalarPerVector,
                                               ABlockTransferDstScalarPerVector_M1,
                                               AThreadTransferSrcResetCoordinateAfterRun,
                                               BBlockTransferThreadSliceLengths_K_N0_N1,
                                               BBlockTransferThreadClusterLengths_K_N0_N1,
                                               BBlockTransferThreadClusterArrangeOrder,
                                               BBlockTransferSrcAccessOrder,
                                               BBlockTransferSrcVectorDim,
                                               BBlockTransferSrcScalarPerVector,
                                               BBlockTransferDstScalarPerVector_N1,
                                               BThreadTransferSrcResetCoordinateAfterRun,
                                               CThreadTransferSrcDstAccessOrder,
                                               CThreadTransferSrcDstVectorDim,
                                               CThreadTransferDstScalarPerVector,
                                               AGridIteratorHacks,
                                               BGridIteratorHacks,
                                               CGridIteratorHacks,
                                               AGridMoveSliceWindowIteratorHacks,
                                               BGridMoveSliceWindowIteratorHacks>;

    constexpr auto a_k_m0_m1_grid_desc_tmp =
        GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    constexpr auto b_k_n0_n1_grid_desc_tmp =
        GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
    constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
    constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);

    const auto a_k_m0_m1_grid_desc =
        *reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
    const auto b_k_n0_n1_grid_desc =
        *reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k_m0_m1_grid_desc,
                      b_k_n0_n1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
};
@@ -1,362 +0,0 @@
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
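// K1 is the innermost slice of the GEMM K dimension: per the *_K0_M_K1 /
// *_K0_N_K1 naming below, the A and B grid descriptors are laid out as
// [K0, M/N, K1] with K = K0 * K1, and the *DstScalarPerVector_K1 transfer
// widths are expressed along that K1 axis.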

using ABlockTransferThreadSliceLengths_K0_M_K1 =
    Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
    Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;

constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
    CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);

using BBlockTransferThreadSliceLengths_K0_N_K1 =
    Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
    Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;

constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
    CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);

using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
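
// All CK_PARAM_* macros above must be supplied on the compiler command line when
// this translation unit is compiled (typically at online-compilation time). A
// purely illustrative invocation — the concrete values are tuning choices, not
// fixed by this file:
//   hipcc ... -DCK_PARAM_BlockSize=256 -DCK_PARAM_MPerBlock=128
//             -DCK_PARAM_KPerBlock=4 "-DCK_PARAM_ABlockTransferSrcAccessOrder=0,1,2" ...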

extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
    int n,
    int c,
    int hi,
    int wi,
    int k,
    int y,
    int x,
    int convStrideH,
    int convStrideW,
    int convDilationY,
    int convDilationX,
    int leftPadH,
    int leftPadW,
    int rightPadH,
    int rightPadW,
    void* p_a_k0_m_k1_grid_desc,
    void* p_b_k0_n_k1_grid_desc,
    void* p_c_m0_m1_m2_n_grid_desc,
    void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
    const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
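
    // Standard dilated-convolution output-size formula. A worked example matching
    // the fixed shapes used later in this file: hi = 28, pads = 1/1, y = 3,
    // dilation = 1, stride = 1 gives ho = (28 + 1 + 1 - 1 * (3 - 1) - 1) / 1 + 1 = 28.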

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));

    const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
        wei_k_c_y_x_desc,
        in_n_c_hi_wi_desc,
        out_n_k_ho_wo_desc,
        make_tuple(convStrideH, convStrideW),
        make_tuple(convDilationY, convDilationX),
        make_tuple(leftPadH, leftPadW),
        make_tuple(rightPadH, rightPadW),
        Number<K1>{});

    const auto a_k0_m_k1_grid_desc = descs[I0];
    const auto b_k0_n_k1_grid_desc = descs[I1];
    const auto c_m_n_grid_desc = descs[I2];

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using AGridIteratorHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using BGridIteratorHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
        GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                       FloatAB,
                                                       FloatAcc,
                                                       FloatC,
                                                       InMemoryDataOperationEnum_t::Set,
                                                       AK0MK1GridDesc,
                                                       BK0NK1GridDesc,
                                                       CMNGridDesc,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       MPerWave,
                                                       NPerWave,
                                                       K1,
                                                       MRepeat,
                                                       NRepeat,
                                                       ABlockTransferThreadSliceLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterArrangeOrder,
                                                       ABlockTransferSrcAccessOrder,
                                                       ABlockTransferSrcVectorDim,
                                                       ABlockTransferSrcScalarPerVector,
                                                       ABlockTransferDstScalarPerVector_K1,
                                                       AThreadTransferSrcResetCoordinateAfterRun,
                                                       BBlockTransferThreadSliceLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterArrangeOrder,
                                                       BBlockTransferSrcAccessOrder,
                                                       BBlockTransferSrcVectorDim,
                                                       BBlockTransferSrcScalarPerVector,
                                                       BBlockTransferDstScalarPerVector_K1,
                                                       BThreadTransferSrcResetCoordinateAfterRun,
                                                       CThreadTransferSrcDstAccessOrder,
                                                       CThreadTransferSrcDstVectorDim,
                                                       CThreadTransferDstScalarPerVector,
                                                       AGridIteratorHacks,
                                                       BGridIteratorHacks,
                                                       CGridIteratorHacks,
                                                       AGridMoveSliceWindowIteratorHacks,
                                                       BGridMoveSliceWindowIteratorHacks,
                                                       false>;

    auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

    auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
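
    // Every thread computed the descriptors redundantly above; only one thread
    // needs to publish them to global memory, where the main kernel reads them
    // back through its CONSTANT pointer arguments.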
    if(hipThreadIdx_x == 0)
    {
        *static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
            a_k0_m_k1_grid_desc;
        *static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
            b_k0_n_k1_grid_desc;
        *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
            c_m0_m1_m2_n_grid_desc;
        *static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
            p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
    }
};
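
// CK_USE_LAUNCH_BOUNDS gates HIP's __launch_bounds__(max threads per block,
// min blocks per CU) occupancy hint, letting the compiler budget registers for
// the main kernel below.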

extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const void CONSTANT* p_a_k0_m_k1_grid_desc,
            const void CONSTANT* p_b_k0_n_k1_grid_desc,
            const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
            const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    constexpr auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
    constexpr auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
    constexpr auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
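
    // The fixed 256/28/3 shapes above are compile-time placeholders: they exist
    // only so the descriptor *types* can be derived below. The actual descriptor
    // values are the ones the _prepare kernel stored, reloaded further down
    // through the CONSTANT pointers.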

    constexpr auto descs =
        transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                          in_n_c_hi_wi_desc,
                                                                          out_n_k_ho_wo_desc,
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          Number<K1>{});

    constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
    constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
    constexpr auto c_m_n_grid_desc = descs[I2];

    using AGridIteratorHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using BGridIteratorHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using GridwiseGemm =
        GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                       FloatAB,
                                                       FloatAcc,
                                                       FloatC,
                                                       InMemoryDataOperationEnum_t::Set,
                                                       AK0MK1GridDesc,
                                                       BK0NK1GridDesc,
                                                       CMNGridDesc,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       MPerWave,
                                                       NPerWave,
                                                       K1,
                                                       MRepeat,
                                                       NRepeat,
                                                       ABlockTransferThreadSliceLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterArrangeOrder,
                                                       ABlockTransferSrcAccessOrder,
                                                       ABlockTransferSrcVectorDim,
                                                       ABlockTransferSrcScalarPerVector,
                                                       ABlockTransferDstScalarPerVector_K1,
                                                       AThreadTransferSrcResetCoordinateAfterRun,
                                                       BBlockTransferThreadSliceLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterArrangeOrder,
                                                       BBlockTransferSrcAccessOrder,
                                                       BBlockTransferSrcVectorDim,
                                                       BBlockTransferSrcScalarPerVector,
                                                       BBlockTransferDstScalarPerVector_K1,
                                                       BThreadTransferSrcResetCoordinateAfterRun,
                                                       CThreadTransferSrcDstAccessOrder,
                                                       CThreadTransferSrcDstVectorDim,
                                                       CThreadTransferDstScalarPerVector,
                                                       AGridIteratorHacks,
                                                       BGridIteratorHacks,
                                                       CGridIteratorHacks,
                                                       AGridMoveSliceWindowIteratorHacks,
                                                       BGridMoveSliceWindowIteratorHacks,
                                                       false>;

    constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
        GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
    constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);

    const auto a_k0_m_k1_grid_desc =
        *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
    const auto b_k0_n_k1_grid_desc =
        *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
    const auto c_m0_m1_m2_n_grid_desc =
        *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m_k1_grid_desc,
                      b_k0_n_k1_grid_desc,
                      c_m0_m1_m2_n_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor);
};

@@ -1,362 +0,0 @@
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"

using namespace ck;

constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);

using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;

constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;

using ABlockTransferThreadSliceLengths_K0_M_K1 =
    Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
    Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;

constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
    CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);

using BBlockTransferThreadSliceLengths_K0_N_K1 =
    Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
    Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
    Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;

constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
    CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
    static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);

using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;

extern "C" __global__ void
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
    int n,
    int hi,
    int wi,
    int c,
    int k,
    int y,
    int x,
    int convStrideH,
    int convStrideW,
    int convDilationY,
    int convDilationX,
    int leftPadH,
    int leftPadW,
    int rightPadH,
    int rightPadW,
    void* p_a_k0_m_k1_grid_desc,
    void* p_b_k0_n_k1_grid_desc,
    void* p_c_m0_m1_m2_n_grid_desc,
    void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
    const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;

    const auto in_n_hi_wi_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c));
    const auto wei_k_y_x_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, c));
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k));

    const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
        in_n_hi_wi_c_desc,
        wei_k_y_x_c_desc,
        out_n_ho_wo_k_desc,
        make_tuple(convStrideH, convStrideW),
        make_tuple(convDilationY, convDilationX),
        make_tuple(leftPadH, leftPadW),
        make_tuple(rightPadH, rightPadW),
        Number<K1>{});

    const auto a_k0_m_k1_grid_desc = descs[I0];
    const auto b_k0_n_k1_grid_desc = descs[I1];
    const auto c_m_n_grid_desc = descs[I2];

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using BGridIteratorHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using AGridIteratorHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
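
    // Note the swap relative to the NCHW variant earlier in this diff: the long
    // hack sequences now sit on the A side and the short ones on the B side,
    // matching which GEMM operand carries the padded/strided image transforms in
    // this NHWC formulation.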

    using GridwiseGemm =
        GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                       FloatAB,
                                                       FloatAcc,
                                                       FloatC,
                                                       InMemoryDataOperationEnum_t::Set,
                                                       AK0MK1GridDesc,
                                                       BK0NK1GridDesc,
                                                       CMNGridDesc,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       MPerWave,
                                                       NPerWave,
                                                       K1,
                                                       MRepeat,
                                                       NRepeat,
                                                       ABlockTransferThreadSliceLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterArrangeOrder,
                                                       ABlockTransferSrcAccessOrder,
                                                       ABlockTransferSrcVectorDim,
                                                       ABlockTransferSrcScalarPerVector,
                                                       ABlockTransferDstScalarPerVector_K1,
                                                       AThreadTransferSrcResetCoordinateAfterRun,
                                                       BBlockTransferThreadSliceLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterArrangeOrder,
                                                       BBlockTransferSrcAccessOrder,
                                                       BBlockTransferSrcVectorDim,
                                                       BBlockTransferSrcScalarPerVector,
                                                       BBlockTransferDstScalarPerVector_K1,
                                                       BThreadTransferSrcResetCoordinateAfterRun,
                                                       CThreadTransferSrcDstAccessOrder,
                                                       CThreadTransferSrcDstVectorDim,
                                                       CThreadTransferDstScalarPerVector,
                                                       AGridIteratorHacks,
                                                       BGridIteratorHacks,
                                                       CGridIteratorHacks,
                                                       AGridMoveSliceWindowIteratorHacks,
                                                       BGridMoveSliceWindowIteratorHacks,
                                                       false>;

    auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

    auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    if(hipThreadIdx_x == 0)
    {
        *static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
            a_k0_m_k1_grid_desc;
        *static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
            b_k0_n_k1_grid_desc;
        *static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
            c_m0_m1_m2_n_grid_desc;
        *static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
            p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
    }
};

extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const void CONSTANT* p_a_k0_m_k1_grid_desc,
            const void CONSTANT* p_b_k0_n_k1_grid_desc,
            const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
            const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_n_hi_wi_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
    constexpr auto wei_k_y_x_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256));
    constexpr auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));

    constexpr auto descs =
        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
                                                                          wei_k_y_x_c_desc,
                                                                          out_n_ho_wo_k_desc,
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          make_tuple(1, 1),
                                                                          Number<K1>{});

    constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
    constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
    constexpr auto c_m_n_grid_desc = descs[I2];

    using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
    using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
    using CMNGridDesc = decltype(c_m_n_grid_desc);

    using BGridIteratorHacks = decltype(make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
        make_tuple(
            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));

    using AGridIteratorHacks =
        decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
                            make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
                                       Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));

    using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 1, 0, 0>{}),
                                                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 0, 0, 0>{},
                                                              Sequence<0, 0, 2, 0, 0>{})));

    using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;

    using GridwiseGemm =
        GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                       FloatAB,
                                                       FloatAcc,
                                                       FloatC,
                                                       InMemoryDataOperationEnum_t::Set,
                                                       AK0MK1GridDesc,
                                                       BK0NK1GridDesc,
                                                       CMNGridDesc,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       MPerWave,
                                                       NPerWave,
                                                       K1,
                                                       MRepeat,
                                                       NRepeat,
                                                       ABlockTransferThreadSliceLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterArrangeOrder,
                                                       ABlockTransferSrcAccessOrder,
                                                       ABlockTransferSrcVectorDim,
                                                       ABlockTransferSrcScalarPerVector,
                                                       ABlockTransferDstScalarPerVector_K1,
                                                       AThreadTransferSrcResetCoordinateAfterRun,
                                                       BBlockTransferThreadSliceLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterArrangeOrder,
                                                       BBlockTransferSrcAccessOrder,
                                                       BBlockTransferSrcVectorDim,
                                                       BBlockTransferSrcScalarPerVector,
                                                       BBlockTransferDstScalarPerVector_K1,
                                                       BThreadTransferSrcResetCoordinateAfterRun,
                                                       CThreadTransferSrcDstAccessOrder,
                                                       CThreadTransferSrcDstVectorDim,
                                                       CThreadTransferDstScalarPerVector,
                                                       AGridIteratorHacks,
                                                       BGridIteratorHacks,
                                                       CGridIteratorHacks,
                                                       AGridMoveSliceWindowIteratorHacks,
                                                       BGridMoveSliceWindowIteratorHacks,
                                                       false>;
    constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
        GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
    constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);

    const auto a_k0_m_k1_grid_desc =
        *reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
    const auto b_k0_n_k1_grid_desc =
        *reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
    const auto c_m0_m1_m2_n_grid_desc =
        *reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m_k1_grid_desc,
                      b_k0_n_k1_grid_desc,
                      c_m0_m1_m2_n_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor);
};
5670 external/half/include/half.hpp vendored
File diff suppressed because it is too large
@@ -1,4 +1,2 @@
add_subdirectory(host_tensor)
add_subdirectory(online_compile)
add_subdirectory(driver_offline)
add_subdirectory(driver_online)

@@ -9,11 +9,10 @@ include_directories(BEFORE
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
    ${PROJECT_SOURCE_DIR}/external/rocm/include
    ${PROJECT_SOURCE_DIR}/external/half/include
)

set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp)
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)

add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})

@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

template <typename TInWei,
          typename TAcc,
@@ -14,7 +14,7 @@ template <typename TInWei,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
    const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 1
    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
@@ -215,7 +207,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
    const auto in_gemmm_gemmn_grid_desc = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
@@ -223,7 +215,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple(
    constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
@@ -231,7 +223,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1

    constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
    constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
@@ -251,15 +243,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1

    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};

    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
    constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_dynamic_gemm_xdlops_v2r3<
        float ave_time = driver_gemm_xdlops_v2r3<
            BlockSize,
            TInWei,
            TAcc,
@@ -295,11 +287,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
            Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
            6,
            GemmCThreadTransferDstScalarPerVector,
            decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
            decltype(in_m0_m1_m2_n_grid_iterator_hacks),
            decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
            decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
            decltype(in_m0_m1_m2_n_grid_step_hacks),
            decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
            decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
            false // CAccessOrderMRepeatNRepeat
            >(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
              static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
@@ -307,11 +299,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
              wei_gemmk0_gemmm_gemmk1_grid_desc,
              out_gemmk0_gemmn_gemmk1_grid_desc,
              in_gemmm_gemmn_grid_desc,
              wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
              out_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
              in_m0_m1_m2_n_grid_iterator_hacks,
              wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
              out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
              wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
              out_gemmk0_gemmn_gemmk1_grid_step_hacks,
              in_m0_m1_m2_n_grid_step_hacks,
              wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
              out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
              nrepeat);

        {
@@ -319,16 +311,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
            const auto K = out_n_ho_wo_k_lengths[I3];
            const auto C = wei_k_y_x_c_lengths[I3];

            const auto Hi = in_n_hi_wi_c_lengths[I1];
            const auto Wi = in_n_hi_wi_c_lengths[I2];

            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];

            const auto Y = wei_k_y_x_c_lengths[I1];
            const auto X = wei_k_y_x_c_lengths[I2];

            float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
@@ -187,7 +179,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
@@ -195,7 +187,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
@@ -203,7 +195,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
@@ -223,15 +215,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
@@ -271,11 +263,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
@@ -283,11 +275,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
out_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
||||
in_m0_m1_m2_n_grid_iterator_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
@@ -295,16 +287,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_gemm_dlops_v1r2.hpp"
|
||||
#include "driver_gemm_dlops_v1r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
@@ -98,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
in_right_pads);
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
@@ -108,7 +99,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
@@ -116,7 +107,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
@@ -130,10 +121,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||
@@ -142,7 +133,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_dynamic_gemm_dlops_v1r2<
|
||||
float ave_time = driver_gemm_dlops_v1r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
@@ -180,26 +171,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
|
||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
|
||||
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
@@ -1,7 +1,7 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"

template <typename TInWei,
typename TAcc,
@@ -13,7 +13,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
@@ -48,12 +48,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 0
constexpr index_t BlockSize = 256;
@@ -212,9 +209,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
for(index_t i = 0; i < 5; ++i)
{
#if 0
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1
float ave_time = launch_kernel_gemm_xdlops_v1
#else
float ave_time = launch_kernel_dynamic_gemm_xdlops_v2
float ave_time = launch_kernel_gemm_xdlops_v2
#endif
<BlockSize,
TInWei,
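Aside on the descriptor rename above: a "packed" naive tensor descriptor describes a contiguous row-major layout, so its strides are the suffix products of the lengths. A minimal sketch of that idea, independent of CK's actual descriptor machinery:

#include <array>
#include <cstddef>

// Minimal sketch of what a packed naive tensor descriptor encodes: for
// lengths [L0, L1, ..., Ln-1], stride[i] is the product of all lengths to
// the right of i, i.e. a contiguous row-major layout.
template <std::size_t N>
std::array<std::size_t, N> packed_strides(const std::array<std::size_t, N>& lengths)
{
    std::array<std::size_t, N> strides{};
    std::size_t s = 1;
    for(std::size_t i = N; i-- > 0;)
    {
        strides[i] = s;
        s *= lengths[i];
    }
    return strides;
}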
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_dlops_v1r3.hpp"
#include "driver_gemm_dlops_v1r3.hpp"

template <typename TInWei,
typename TAcc,
@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};

DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
@@ -49,14 +44,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 1
#if 0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
@@ -163,7 +155,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const auto out_gemmm_gemmn_grid_desc = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple(
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
@@ -173,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks =
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
@@ -183,7 +175,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
@@ -197,15 +189,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11

constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};

constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_dlops_v1r3<
float ave_time = driver_gemm_dlops_v1r3<
BlockSize,
TInWei,
TAcc,
@@ -239,22 +231,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>(
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);

{
@@ -262,16 +254,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];

const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];

const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];

const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];

float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
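The Sequence<...> arguments inside the step-hack tuples above are compile-time integer lists carried purely in the type; the objects themselves are empty tags. A minimal sketch of such a tag type (CK's real Sequence carries far more machinery; this only illustrates the idea):

// Minimal sketch only: the integers live in the template arguments, so an
// object like Sequence<0, 0, 1, 0>{} is empty and costs nothing at run time.
template <int... Is>
struct Sequence
{
    static constexpr int mSize = sizeof...(Is);
};

// The step-hack tuples then pair one Sequence per tensor dimension, once for
// forward ("+") index steps and once for backward ("-") steps, letting the
// kernel specialize its index arithmetic per dimension at compile time.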
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

template <typename TInWei,
typename TAcc,
@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};

DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
@@ -101,12 +92,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
const auto out_gemmm_gemmn_grid_desc = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));

constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
@@ -114,7 +105,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
@@ -132,15 +123,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));

constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};

constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
@@ -176,26 +167,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
out_m0_m1_m2_n_grid_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);

float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
#include "driver_gemm_xdlops_v2r2.hpp"

template <typename TInWei,
typename TAcc,
@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};

DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
@@ -129,12 +121,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
const auto out_gemmm_gemmn_grid_desc = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));

constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
@@ -142,7 +134,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
@@ -152,15 +144,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));

constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};

constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r2<
float ave_time = driver_gemm_xdlops_v2r2<
BlockSize,
TInWei,
TAcc,
@@ -195,22 +187,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
Sequence<2, 3, 0, 1>,
2,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
out_m0_m1_m2_n_grid_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);

{
@@ -218,9 +210,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];

const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];

const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];

@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

template <typename TInWei,
typename TAcc,
@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
@@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
@@ -185,12 +182,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
const auto out_gemmm_gemmn_grid_desc = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));

constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
@@ -198,7 +195,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));

constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
@@ -216,15 +213,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));

constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};

constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
@@ -259,11 +256,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
@@ -271,11 +268,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
out_m0_m1_m2_n_grid_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);

{
@@ -2,7 +2,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"

template <typename TInWei,
typename TAcc,
@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};

DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);

#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
@@ -241,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
const auto out_gemmm_gemmn_grid_desc = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
@@ -249,7 +241,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1

constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
@@ -257,7 +249,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1

constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
@@ -275,15 +267,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1

constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
@@ -319,11 +311,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
@@ -331,11 +323,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
out_m0_m1_m2_n_grid_iterator_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);

{
@@ -343,16 +335,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];

const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];

const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];

const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];

float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
@@ -1,8 +1,8 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"

template <typename TInWei,
ck::index_t InWeiVectorSize,
@@ -15,7 +15,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
@@ -26,7 +26,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
ck::index_t /* nrepeat */)
{
using namespace ck;

@@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());

const auto in_n_c0_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));

#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
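The 5-d out_n_k0_ho_wo_k1 descriptor above splits the output channel dimension K into (K0, K1) so that K1 consecutive channels can be moved as one vector. The index arithmetic behind such a split, as a minimal sketch (names are illustrative, not CK's):

#include <cstddef>

// Illustrative only: splitting k into (k0, k1) with K = K0 * K1 means
// k = k0 * K1 + k1, so fixing k0 gives K1 consecutive channels that can be
// read or written as a single vector.
inline std::size_t merged_k(std::size_t k0, std::size_t k1, std::size_t K1)
{
    return k0 * K1 + k1;
}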
@@ -3,7 +3,7 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_dlops_v1r2.hpp"
#include "driver_contraction_dlops_v1r2.hpp"

template <typename TInWei,
typename TAcc,
@@ -15,7 +15,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

const auto in_desc_n_c_hi_wi =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_desc_k_c_y_x =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_desc_n_k_ho_wo =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);

#if 1
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
@@ -133,7 +130,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];

// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_grid_iterator_hacks =
constexpr auto wei_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
@@ -145,7 +142,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1

constexpr auto in_grid_iterator_hacks = make_tuple(
constexpr auto in_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
@@ -157,7 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1

constexpr auto out_grid_iterator_hacks = make_tuple(
constexpr auto out_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
@@ -173,14 +170,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1

constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};

constexpr auto in_grid_move_slice_window_iterator_hacks =
constexpr auto in_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};

for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_contraction_dlops_v1r2<
float ave_time = driver_contraction_dlops_v1r2<
BlockSize,
TInWei,
TAcc,
@@ -214,26 +211,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector_BN1,
decltype(wei_grid_iterator_hacks),
decltype(in_grid_iterator_hacks),
decltype(out_grid_iterator_hacks),
decltype(wei_grid_move_slice_window_iterator_hacks),
decltype(in_grid_move_slice_window_iterator_hacks)>(
decltype(wei_grid_step_hacks),
decltype(in_grid_step_hacks),
decltype(out_grid_step_hacks),
decltype(wei_grid_move_slice_window_step_hacks),
decltype(in_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_grid_desc_gk0_gm0_gm1_gk1,
in_grid_desc_gk0_gn0_gn1_gk1,
out_grid_desc_gm0_gm1_gn0_gn1,
wei_grid_iterator_hacks,
in_grid_iterator_hacks,
out_grid_iterator_hacks,
wei_grid_move_slice_window_iterator_hacks,
in_grid_move_slice_window_iterator_hacks,
wei_grid_step_hacks,
in_grid_step_hacks,
out_grid_step_hacks,
wei_grid_move_slice_window_step_hacks,
in_grid_move_slice_window_step_hacks,
nrepeat);

float perf = (float)calculate_convolution_flops(
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) /
float perf = static_cast<float>(calculate_convolution_flops(
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
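For orientation, the contraction driver invoked above computes, up to blocking and layout details, C[gm0,gm1,gn0,gn1] += A[gk0,gm0,gm1,gk1] * B[gk0,gn0,gn1,gk1] reduced over gk0 and gk1. A naive reference loop nest, as a sketch only (it assumes packed row-major layouts; the real grid descriptors carry non-trivial strides):

// Sketch of the contraction semantics with packed row-major layouts; the
// dimension names mirror the grid descriptors above.
void contraction_reference(const float* A, const float* B, float* C,
                           int GK0, int GM0, int GM1, int GK1, int GN0, int GN1)
{
    for(int gm0 = 0; gm0 < GM0; ++gm0)
        for(int gm1 = 0; gm1 < GM1; ++gm1)
            for(int gn0 = 0; gn0 < GN0; ++gn0)
                for(int gn1 = 0; gn1 < GN1; ++gn1)
                {
                    float acc = 0;
                    for(int gk0 = 0; gk0 < GK0; ++gk0)
                        for(int gk1 = 0; gk1 < GK1; ++gk1)
                            acc += A[((gk0 * GM0 + gm0) * GM1 + gm1) * GK1 + gk1] *
                                   B[((gk0 * GN0 + gn0) * GN1 + gn1) * GK1 + gk1];
                    C[((gm0 * GM1 + gm1) * GN0 + gn0) * GN1 + gn1] = acc;
                }
}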
@@ -1,10 +1,10 @@
#ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"

template <ck::index_t BlockSize,
typename FloatAB,
@@ -39,24 +39,24 @@ template <ck::index_t BlockSize,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
__host__ float
driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat)
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
ck::index_t nrepeat)

{
using namespace ck;
@@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,

// GEMM
using GridwiseContraction =
GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
@@ -104,11 +104,11 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;

const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);

@@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
{
throw std::runtime_error("wrong! "
"GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting");
}

@@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
@@ -205,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
@@ -232,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
@@ -259,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
}
else
{
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
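The four branches above exist because has_main_k_block_loop and has_double_tail_k_block_loop are runtime values that must be baked into the kernel as compile-time template parameters. The dispatch pattern in isolation, as a sketch (the names launch_kernel_variant and dispatch are placeholders, not CK symbols):

// Placeholder names; only the dispatch shape matters: each combination of the
// two runtime flags selects a distinct compile-time instantiation.
template <bool HasMainLoop, bool HasDoubleTail>
void launch_kernel_variant()
{
    // ... instantiate and launch the kernel specialized on the two flags ...
}

void dispatch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
{
    if(has_main_k_block_loop && has_double_tail_k_block_loop)
        launch_kernel_variant<true, true>();
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        launch_kernel_variant<true, false>();
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        launch_kernel_variant<false, true>();
    else
        launch_kernel_variant<false, false>();
}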
@@ -1,10 +1,10 @@
|
||||
#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_dlops_v2.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
@@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
@@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
make_pass_through_transform(N),
@@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
@@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

// output tensor
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
@@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
}

// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e_k_global_iterator_hacks =
constexpr auto a_e_k_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};

constexpr auto b_e_n_ho_wo_global_iterator_hacks =
constexpr auto b_e_n_ho_wo_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
@@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};

// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
@@ -169,7 +169,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad

#if 1
// GEMM
using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3<
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
@@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
Sequence<0, 2, 3, 1>,
0,
CThreadTransferDstScalarPerVector_W,
decltype(a_e_k_global_iterator_hacks),
decltype(b_e_n_ho_wo_global_iterator_hacks),
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
decltype(a_e_k_global_move_slice_window_iterator_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
decltype(a_e_k_global_step_hacks),
decltype(b_e_n_ho_wo_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
decltype(a_e_k_global_move_slice_window_step_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;

const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;

@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -338,10 +334,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad

float ave_time = timer.GetElapsedTime() / nrepeat;

float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
float perf =
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
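For readers skimming the renames above: transform_tensor_descriptor (formerly transform_dynamic_tensor_descriptor) composes per-dimension transforms over an existing descriptor, and the two trailing Sequence tuples map source dimension ids to target dimension ids. A minimal hedged sketch of the call pattern, using only helpers that appear in this diff; N, C, Hi, Wi and the pad sizes are illustrative placeholders, not values from this commit:

    // Hedged sketch, not code from this commit: pad H and W of an NCHW
    // descriptor while passing N and C through unchanged.
    const auto in_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
    const auto in_padded_desc = transform_tensor_descriptor(
        in_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, LeftPadH, RightPadH),
                   make_pad_transform(Wi, LeftPadW, RightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),   // source dims
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));  // target dims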
@@ -1,10 +1,10 @@
#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v2.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"

template <ck::index_t BlockSize,
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
@@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
<< std::endl;

// weight tensor
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
const auto wei_e_k_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));

// input tensor
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
@@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
make_pass_through_transform(N),
@@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
@@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

// output tensor
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH),
@@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
}

// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e_k_global_iterator_hacks =
constexpr auto a_e_k_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};

constexpr auto b_e_n_ho_wo_global_iterator_hacks =
constexpr auto b_e_n_ho_wo_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
@@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};

// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
@@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence<0, 0, 0, 0, 0>{}));

// GEMM
using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3<
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
@@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
Sequence<0, 2, 3, 1>,
0,
CThreadTransferDstScalarPerVector_W,
decltype(a_e_k_global_iterator_hacks),
decltype(b_e_n_ho_wo_global_iterator_hacks),
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
decltype(a_e_k_global_move_slice_window_iterator_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
decltype(a_e_k_global_step_hacks),
decltype(b_e_n_ho_wo_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
decltype(a_e_k_global_move_slice_window_step_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;

const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;

@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
@@ -354,10 +350,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp

float ave_time = timer.GetElapsedTime() / nrepeat;

float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
float perf =
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;

std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
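A quick unit check on the perf expression above: calculate_convolution_flops returns a FLOP count, ave_time is in milliseconds, and dividing by 1000^3 converts FLOPs to GFLOPs, so GFLOPs per millisecond comes out directly as TFlop/s. A hedged one-line helper restating the same arithmetic:

    // Hedged restatement of the formula above: GFLOP / ms == TFLOP/s.
    inline float tflops(std::size_t flop_count, float ave_time_ms)
    {
        return static_cast<float>(flop_count) / (1000.f * 1000.f * 1000.f) / ave_time_ms;
    }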
@@ -1,415 +0,0 @@
#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2
#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v1r2.hpp"

template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t M1PerThread,
ck::index_t N1PerThread,
ck::index_t KPerThread,
ck::index_t M1N1ThreadClusterM10,
ck::index_t M1N1ThreadClusterN10,
ck::index_t M1N1ThreadClusterM11,
ck::index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat)

{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};

// GEMM
using GridwiseGemm =
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;

const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);

if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting");
}

const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);

// c_m0_m10_m11_n0_n10_n11_grid_desc
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

// c_blockid_to_m0_n0_block_cluster_adaptor
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);

{
std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
<< a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
<< "}" << std::endl;

std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
<< b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
<< "}" << std::endl;

std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
}

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}

return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToM0N0BlockClusterAdaptor));

a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_m0_n0_block_cluster_adaptor);

float ave_time = 0;

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}

return ave_time;
#endif
}
#endif
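The deleted driver above (and its v1r3 successor below) share one dispatch idiom worth calling out: has_main_k_block_loop and has_double_tail_k_block_loop are runtime values, but the kernel needs them as compile-time flags, so each (bool, bool) combination instantiates a separate kernel. A self-contained hedged sketch of just that idiom, with launch_selected_kernel standing in for the real kernel selection and launch:

    #include <iostream>

    // Hedged sketch, not CK code: fan two runtime booleans out into four
    // compile-time template instantiations.
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    void launch_selected_kernel()
    {
        std::cout << "main=" << HasMainKBlockLoop
                  << " double_tail=" << HasDoubleTailKBlockLoop << '\n';
    }

    void dispatch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
    {
        if(has_main_k_block_loop && has_double_tail_k_block_loop)
            launch_selected_kernel<true, true>();
        else if(has_main_k_block_loop)
            launch_selected_kernel<true, false>();
        else if(has_double_tail_k_block_loop)
            launch_selected_kernel<false, true>();
        else
            launch_selected_kernel<false, false>();
    }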
@@ -1,411 +0,0 @@
#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3
#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_dlops_v1r3.hpp"

template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t M1PerThread,
ck::index_t N1PerThread,
ck::index_t KPerThread,
typename M1N1ThreadClusterM1Xs,
typename M1N1ThreadClusterN1Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat)

{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};

// GEMM
using GridwiseGemm =
GridwiseDynamicGemmDlops_km_kn_mn_v1r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM1Xs,
M1N1ThreadClusterN1Xs,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;

const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting");
}

const auto a_k0_m0_m1_k1_grid_desc =
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
const auto b_k0_n0_n1_k1_grid_desc =
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);

using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);

// c_m0_m10_m11_n0_n10_n11_grid_desc
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

// c_blockid_to_m0_n0_block_cluster_adaptor
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

{
std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
<< a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
<< a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
<< b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
<< b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
}

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}

return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToM0N0BlockClusterAdaptor));

a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_m0_n0_block_cluster_adaptor);

float ave_time = 0;

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}
else
{
const auto kernel =
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0M0M1K1GridDesc>,
remove_reference_t<BK0N0N1K1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
}

return ave_time;
#endif
}
#endif
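One more recurring detail in these drivers: descriptor types are captured with decltype on the constructed objects and passed through remove_reference_t before instantiating the kernel, so a reference type can never become a kernel template argument. A hedged standalone illustration; instantiate_kernel_for is a placeholder, not a CK symbol:

    #include <type_traits>

    template <typename Desc>
    void instantiate_kernel_for()
    {
        // mirrors the remove_reference_t<...> guards in the drivers above
        static_assert(!std::is_reference_v<Desc>, "kernel arguments must be value types");
    }

    int main()
    {
        const auto c_desc = 123; // stand-in for a Make*GridDescriptor(...) result
        using CDesc = std::remove_reference_t<decltype(c_desc)>;
        instantiate_kernel_for<CDesc>();
    }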
@@ -1,196 +0,0 @@
#ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
#define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"

template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t MPerWave,
ck::index_t NPerWave,
ck::index_t K1,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks,
bool CAccessOrderMRepeatNRepeat>
__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
ck::index_t nrepeat)

{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};

using GridwiseGemm =
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
CAccessOrderMRepeatNRepeat>;

{
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;

std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
<< "}" << std::endl;

std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}

if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}

const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);

const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);

const auto kernel = kernel_dynamic_gemm_xdlops_v2r3<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<CM0M1M2NGridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);

#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));

a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

float ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
#endif
return ave_time;
}
#endif
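The CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER path seen in each driver above copies every descriptor into device memory and hands the kernel an opaque pointer instead of passing the object by value. A hedged plain-HIP sketch of that pattern, using hipMalloc/hipMemcpy in place of CK's DeviceMem and an ordinary global pointer in place of the CONSTANT address-space cast; Desc is a placeholder POD:

    #include <hip/hip_runtime.h>
    #include <cstdio>

    struct Desc { int lengths[4]; };  // placeholder POD descriptor

    __global__ void kernel(const Desc* p)
    {
        // the real kernels reinterpret a constant-address-space pointer;
        // an ordinary global pointer is used here for illustration
        if(threadIdx.x == 0 && blockIdx.x == 0)
            printf("len0=%d\n", p->lengths[0]);
    }

    int main()
    {
        Desc h{{8, 16, 32, 64}};
        Desc* d = nullptr;
        hipMalloc(&d, sizeof(Desc));
        hipMemcpy(d, &h, sizeof(Desc), hipMemcpyHostToDevice);  // ToDevice analogue
        hipLaunchKernelGGL(kernel, dim3(1), dim3(1), 0, 0, d);
        hipDeviceSynchronize();
        hipFree(d);
    }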
413
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
Normal file
@@ -0,0 +1,413 @@
#ifndef DRIVER_GEMM_DLOPS_V1R2
#define DRIVER_GEMM_DLOPS_V1R2

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r2.hpp"

template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t M1PerThread,
ck::index_t N1PerThread,
ck::index_t KPerThread,
ck::index_t M1N1ThreadClusterM10,
ck::index_t M1N1ThreadClusterN10,
ck::index_t M1N1ThreadClusterM11,
ck::index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
ck::index_t nrepeat)

{
using namespace ck;

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};

// GEMM
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;

const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);

if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
{
throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting");
}

const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);

// c_m0_m10_m11_n0_n10_n11_grid_desc
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

// c_blockid_to_m0_n0_block_cluster_adaptor
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);

{
std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
<< a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
<< "}" << std::endl;

std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
<< b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
<< "}" << std::endl;

std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
}

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
false>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
true>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}
else
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
false,
false>;

ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
}

return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
sizeof(CBlockIdToM0N0BlockClusterAdaptor));

a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
&c_blockid_to_m0_n0_block_cluster_adaptor);

float ave_time = 0;

if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dlops_v1r2<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AKM0M1GridDesc>,
remove_reference_t<BKN0N1GridDesc>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
true,
true>;

ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AKM0M1GridDesc>,
|
||||
remove_reference_t<BKN0N1GridDesc>,
|
||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||
cast_pointer_to_constant_address_space(
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
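The four-way branch above trades two runtime booleans for compile-time template parameters: the host inspects K once, then instantiates the one kernel variant whose loop structure matches, so the GPU code carries no branch at all. A minimal standalone sketch of the same pattern (plain C++, not CK code):

    #include <iostream>

    // Stand-in for the templated GPU kernel: the two loop properties are
    // baked in as compile-time constants.
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    void run_kernel()
    {
        std::cout << "main loop: " << HasMainKBlockLoop
                  << ", double tail: " << HasDoubleTailKBlockLoop << std::endl;
    }

    // Host-side dispatch: runtime booleans select one of four instantiations,
    // exactly as the driver does with kernel_gemm_dlops_v1r2<..., bool, bool>.
    void dispatch(bool has_main_k_block_loop, bool has_double_tail_k_block_loop)
    {
        if(has_main_k_block_loop && has_double_tail_k_block_loop)
            run_kernel<true, true>();
        else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
            run_kernel<true, false>();
        else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
            run_kernel<false, true>();
        else
            run_kernel<false, false>();
    }

    int main() { dispatch(true, false); }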
418
host/driver_offline/include/driver_gemm_dlops_v1r3.hpp
Normal file
@@ -0,0 +1,418 @@
#ifndef DRIVER_GEMM_DLOPS_V1R3
#define DRIVER_GEMM_DLOPS_V1R3

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r3.hpp"

template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t M1PerThread,
          ck::index_t N1PerThread,
          ck::index_t KPerThread,
          typename M1N1ThreadClusterM1Xs,
          typename M1N1ThreadClusterN1Xs,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks>
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
                                      const FloatAB* p_b_grid,
                                      FloatC* p_c_grid,
                                      const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                      const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                      const CMNGridDesc& c_m_n_grid_desc,
                                      AGridStepHacks,
                                      BGridStepHacks,
                                      CGridStepHacks,
                                      AGridMoveSliceWindowStepHacks,
                                      BGridMoveSliceWindowStepHacks,
                                      ck::index_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm =
        GridwiseGemmDlops_km_kn_mn_v1r3<BlockSize,
                                        FloatAB,
                                        FloatAcc,
                                        FloatC,
                                        CGlobalMemoryDataOperation,
                                        AK0MK1GridDesc,
                                        BK0NK1GridDesc,
                                        CMNGridDesc,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        M1PerThread,
                                        N1PerThread,
                                        KPerThread,
                                        M1N1ThreadClusterM1Xs,
                                        M1N1ThreadClusterN1Xs,
                                        ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                        ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                        ABlockTransferThreadClusterArrangeOrder,
                                        ABlockTransferSrcAccessOrder,
                                        ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                        ABlockTransferSrcVectorTensorContiguousDimOrder,
                                        ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                        BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                        BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                        BBlockTransferThreadClusterArrangeOrder,
                                        BBlockTransferSrcAccessOrder,
                                        BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                        BBlockTransferSrcVectorTensorContiguousDimOrder,
                                        BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                        CThreadTransferSrcDstAccessOrder,
                                        CThreadTransferSrcDstVectorDim,
                                        CThreadTransferDstScalarPerVector,
                                        AGridStepHacks,
                                        BGridStepHacks,
                                        CGridStepHacks,
                                        AGridMoveSliceWindowStepHacks,
                                        BGridMoveSliceWindowStepHacks>;

    const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
    const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
    const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
    }

    const auto a_k0_m0_m1_k1_grid_desc =
        GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
    const auto b_k0_n0_n1_k1_grid_desc =
        GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);

    using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
    using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // c_blockid_to_m0_n0_block_cluster_adaptor
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

    {
        std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
    DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
    DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
    DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
        sizeof(CBlockIdToM0N0BlockClusterAdaptor));

    a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
    b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
    c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
    c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
        &c_blockid_to_m0_n0_block_cluster_adaptor);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(
                a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   true,
                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(
                a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(
                a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
    }
    else
    {
        const auto kernel =
            kernel_gemm_dlops_v1r3<GridwiseGemm,
                                   FloatAB,
                                   FloatC,
                                   remove_reference_t<AK0M0M1K1GridDesc>,
                                   remove_reference_t<BK0N0N1K1GridDesc>,
                                   remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                   remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                   false,
                                   false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            cast_pointer_to_constant_address_space(
                a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
            cast_pointer_to_constant_address_space(
                c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
    }

    return ave_time;
#endif
}
#endif
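The CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER branches above copy each descriptor into device memory before launch instead of passing it by value as a kernel argument. A minimal standalone sketch of what the DeviceMem(sizeof(...))/ToDevice(...) pair amounts to, written against the raw HIP API; DescSketch is a hypothetical stand-in for a CK descriptor, which works here only because the real descriptors are trivially copyable:

    #include <hip/hip_runtime.h>

    // Hypothetical descriptor stand-in: plain-old-data, so a byte-wise
    // host-to-device copy is legal.
    struct DescSketch
    {
        int lengths[4];
        int strides[4];
    };

    int main()
    {
        DescSketch host_desc{{1, 2, 3, 4}, {24, 12, 4, 1}};

        void* dev_ptr = nullptr;
        hipMalloc(&dev_ptr, sizeof(DescSketch)); // like DeviceMem(sizeof(Desc))
        hipMemcpy(dev_ptr, &host_desc, sizeof(DescSketch),
                  hipMemcpyHostToDevice);        // like ToDevice(&desc)

        // dev_ptr would then be handed to the kernel, as the drivers do with
        // cast_pointer_to_constant_address_space(buf.GetDeviceBuffer()).

        hipFree(dev_ptr);
        return 0;
    }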
191
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
Normal file
@@ -0,0 +1,191 @@
#ifndef DRIVER_GEMM_XDLOPS_V2R3
#define DRIVER_GEMM_XDLOPS_V2R3

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"

template <ck::index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t MPerWave,
          ck::index_t NPerWave,
          ck::index_t K1,
          ck::index_t MRepeat,
          ck::index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          typename AGridStepHacks,
          typename BGridStepHacks,
          typename CGridStepHacks,
          typename AGridMoveSliceWindowStepHacks,
          typename BGridMoveSliceWindowStepHacks,
          bool CAccessOrderMRepeatNRepeat>
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                       const FloatAB* p_b_grid,
                                       FloatC* p_c_grid,
                                       const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                       const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                       const CMNGridDesc& c_m_n_grid_desc,
                                       AGridStepHacks,
                                       BGridStepHacks,
                                       CGridStepHacks,
                                       AGridMoveSliceWindowStepHacks,
                                       BGridMoveSliceWindowStepHacks,
                                       ck::index_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    using GridwiseGemm =
        GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                FloatAB,
                                                FloatAcc,
                                                FloatC,
                                                CGlobalMemoryDataOperation,
                                                AK0MK1GridDesc,
                                                BK0NK1GridDesc,
                                                CMNGridDesc,
                                                MPerBlock,
                                                NPerBlock,
                                                KPerBlock,
                                                MPerWave,
                                                NPerWave,
                                                K1,
                                                MRepeat,
                                                NRepeat,
                                                ABlockTransferThreadSliceLengths_K0_M_K1,
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ABlockTransferSrcAccessOrder,
                                                ABlockTransferSrcVectorDim,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                BBlockTransferThreadSliceLengths_K0_N_K1,
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BBlockTransferSrcAccessOrder,
                                                BBlockTransferSrcVectorDim,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                CThreadTransferSrcDstAccessOrder,
                                                CThreadTransferSrcDstVectorDim,
                                                CThreadTransferDstScalarPerVector,
                                                AGridStepHacks,
                                                BGridStepHacks,
                                                CGridStepHacks,
                                                AGridMoveSliceWindowStepHacks,
                                                BGridMoveSliceWindowStepHacks,
                                                CAccessOrderMRepeatNRepeat>;

    {
        std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
                  << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
                  << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                  << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
    }

    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error(
            "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
    }

    const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);

    const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);

    const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
                                                FloatAB,
                                                FloatC,
                                                remove_reference_t<AK0MK1GridDesc>,
                                                remove_reference_t<BK0NK1GridDesc>,
                                                remove_reference_t<CM0M1M2NGridDesc>,
                                                remove_reference_t<CBlockClusterAdaptor>>;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = launch_and_time_kernel(kernel,
                                            nrepeat,
                                            dim3(grid_size),
                                            dim3(BlockSize),
                                            0,
                                            p_a_grid,
                                            p_b_grid,
                                            p_c_grid,
                                            a_k0_m_k1_grid_desc,
                                            b_k0_n_k1_grid_desc,
                                            c_m0_m1_m2_n_grid_desc,
                                            c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
    DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
    DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
    DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));

    a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
    b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
    c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

    float ave_time = launch_and_time_kernel(
        kernel,
        nrepeat,
        dim3(grid_size),
        dim3(BlockSize),
        0,
        p_a_grid,
        p_b_grid,
        p_c_grid,
        cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
        cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
        cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
        cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
    return ave_time;
}
#endif
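All three drivers hand their kernel to launch_and_time_kernel with a repeat count and get back an average time in milliseconds. The helper's exact signature is only inferred from the call sites above; a minimal sketch of such a helper using HIP events, with a hypothetical no-op kernel standing in for the real one:

    #include <hip/hip_runtime.h>
    #include <cstdio>

    __global__ void dummy_kernel(int) {}

    // Sketch of a launch-and-time helper: record an event pair around
    // nrepeat launches and return the average elapsed time per launch.
    float time_kernel(int nrepeat, dim3 grid, dim3 block)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, 0);
        for(int i = 0; i < nrepeat; ++i)
        {
            hipLaunchKernelGGL(dummy_kernel, grid, block, 0, 0, 0);
        }
        hipEventRecord(stop, 0);
        hipEventSynchronize(stop);

        float total_ms = 0;
        hipEventElapsedTime(&total_ms, start, stop);

        hipEventDestroy(start);
        hipEventDestroy(stop);

        return total_ms / nrepeat;
    }

    int main()
    {
        printf("ave_time: %f ms\n", time_kernel(10, dim3(1), dim3(64)));
        return 0;
    }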
@@ -12,10 +12,10 @@
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_tensor.hpp"
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"

#define USE_DYNAMIC_MODE 1
#define USE_MODE 1
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1

@@ -37,7 +37,7 @@ int main(int argc, char* argv[])
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};

#if USE_DYNAMIC_MODE
#if USE_MODE
    // dynamic mode
    if(argc != 22)
    {
@@ -46,29 +46,29 @@ int main(int argc, char* argv[])
        exit(1);
    }

    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
    const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(atoi(argv[2]));
    const bool do_verification = atoi(argv[3]);
    const int init_method = atoi(argv[4]);
    const bool do_log = atoi(argv[5]);
    const int nrepeat = atoi(argv[6]);
    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
    const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
    const bool do_verification = std::stoi(argv[3]);
    const int init_method = std::stoi(argv[4]);
    const bool do_log = std::stoi(argv[5]);
    const int nrepeat = std::stoi(argv[6]);

    const index_t N = atoi(argv[7]);
    const index_t K = atoi(argv[8]);
    const index_t C = atoi(argv[9]);
    const index_t Y = atoi(argv[10]);
    const index_t X = atoi(argv[11]);
    const index_t Hi = atoi(argv[12]);
    const index_t Wi = atoi(argv[13]);
    const index_t N = std::stoi(argv[7]);
    const index_t K = std::stoi(argv[8]);
    const index_t C = std::stoi(argv[9]);
    const index_t Y = std::stoi(argv[10]);
    const index_t X = std::stoi(argv[11]);
    const index_t Hi = std::stoi(argv[12]);
    const index_t Wi = std::stoi(argv[13]);

    const index_t conv_stride_h = atoi(argv[14]);
    const index_t conv_stride_w = atoi(argv[15]);
    const index_t conv_dilation_h = atoi(argv[16]);
    const index_t conv_dilation_w = atoi(argv[17]);
    const index_t in_left_pad_h = atoi(argv[18]);
    const index_t in_left_pad_w = atoi(argv[19]);
    const index_t in_right_pad_h = atoi(argv[20]);
    const index_t in_right_pad_w = atoi(argv[21]);
    const index_t conv_stride_h = std::stoi(argv[14]);
    const index_t conv_stride_w = std::stoi(argv[15]);
    const index_t conv_dilation_h = std::stoi(argv[16]);
    const index_t conv_dilation_w = std::stoi(argv[17]);
    const index_t in_left_pad_h = std::stoi(argv[18]);
    const index_t in_left_pad_w = std::stoi(argv[19]);
    const index_t in_right_pad_h = std::stoi(argv[20]);
    const index_t in_right_pad_w = std::stoi(argv[21]);

    const index_t YEff = (Y - 1) * conv_dilation_h + 1;
    const index_t XEff = (X - 1) * conv_dilation_w + 1;
@@ -83,12 +83,12 @@ int main(int argc, char* argv[])
        exit(1);
    }

    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
    const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(atoi(argv[2]));
    const bool do_verification = atoi(argv[3]);
    const int init_method = atoi(argv[4]);
    const bool do_log = atoi(argv[5]);
    const int nrepeat = atoi(argv[6]);
    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
    const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
    const bool do_verification = std::stoi(argv[3]);
    const int init_method = std::stoi(argv[4]);
    const bool do_log = std::stoi(argv[5]);
    const int nrepeat = std::stoi(argv[6]);

    constexpr index_t N = 128;
    constexpr index_t C = 192;
@@ -115,23 +115,19 @@ int main(int argc, char* argv[])
#endif

#if 0
    constexpr index_t in_vector_size = 1;
    using in_data_t = float;
    using acc_data_t = float;
    using out_data_t = float;
#elif 1
    constexpr index_t in_vector_size = 1;
    using in_data_t = half_t;
    using acc_data_t = float;
    using out_data_t = half_t;
    using in_data_t = half_t;
    using acc_data_t = float;
    using out_data_t = half_t;
#endif

    std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);

    switch(layout)
    if(layout == ConvTensorLayout::NCHW)
    {
    case ConvTensorLayout::NCHW:
        // NCHW
        in_lengths_host[0] = static_cast<std::size_t>(N);
        in_lengths_host[1] = static_cast<std::size_t>(C);
        in_lengths_host[2] = static_cast<std::size_t>(Hi);
@@ -144,9 +140,9 @@ int main(int argc, char* argv[])
        out_lengths_host[1] = static_cast<std::size_t>(K);
        out_lengths_host[2] = static_cast<std::size_t>(Ho);
        out_lengths_host[3] = static_cast<std::size_t>(Wo);
        break;
    case ConvTensorLayout::NHWC:
        // NHWC
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        in_lengths_host[0] = static_cast<std::size_t>(N);
        in_lengths_host[1] = static_cast<std::size_t>(Hi);
        in_lengths_host[2] = static_cast<std::size_t>(Wi);
@@ -159,8 +155,10 @@ int main(int argc, char* argv[])
        out_lengths_host[1] = static_cast<std::size_t>(Ho);
        out_lengths_host[2] = static_cast<std::size_t>(Wo);
        out_lengths_host[3] = static_cast<std::size_t>(K);
        break;
    default: throw std::runtime_error("wrong! not implemented");
    }
    else
    {
        throw std::runtime_error("wrong! not implemented");
    }

    Tensor<in_data_t> in_host(in_lengths_host);
@@ -213,40 +211,8 @@ int main(int argc, char* argv[])
        wei.GenerateTensorValue(gen_wei, num_thread);
    }

    auto f_make_for_device_nchw = [&]() {
#if USE_DYNAMIC_MODE
        const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
        const auto wei_lengths_dev = make_tuple(K, C, Y, X);
        const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
        const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
        const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
        const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
        const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
#else
        const auto in_lengths_dev =
            make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
        const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
        const auto out_lengths_dev =
            make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
        const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
        const auto conv_dilations_dev =
            make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
        const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
        const auto in_right_pads_dev =
            make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
#endif

        return make_tuple(in_lengths_dev,
                          wei_lengths_dev,
                          out_lengths_dev,
                          conv_strides_dev,
                          conv_dilations_dev,
                          in_left_pads_dev,
                          in_right_pads_dev);
    };

    auto f_make_for_device_nhwc = [&]() {
#if USE_DYNAMIC_MODE
#if USE_MODE
        const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
        const auto wei_lengths_dev = make_tuple(K, Y, X, C);
        const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
@@ -277,8 +243,6 @@ int main(int argc, char* argv[])
            in_right_pads_dev);
    };

    const auto nhwc_desc = f_make_for_device_nhwc();

#if USE_CONV_BWD_V4R1_XDL_NHWC
    if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC)
    {
@@ -289,20 +253,20 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nhwc();

        device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<
            in_data_t,
            acc_data_t,
            out_data_t>(tmp[I0],
                        tmp[I1],
                        tmp[I2],
                        tmp[I3],
                        tmp[I4],
                        tmp[I5],
                        tmp[I6],
                        in_device,
                        wei,
                        out,
                        nrepeat);
        device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<in_data_t,
                                                                                  acc_data_t,
                                                                                  out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in_device,
            wei,
            out,
            nrepeat);
    }
#endif

@@ -316,20 +280,20 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nhwc();

        device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<
            in_data_t,
            acc_data_t,
            out_data_t>(tmp[I0],
                        tmp[I1],
                        tmp[I2],
                        tmp[I3],
                        tmp[I4],
                        tmp[I5],
                        tmp[I6],
                        in_device,
                        wei,
                        out,
                        nrepeat);
        device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
                                                                                    acc_data_t,
                                                                                    out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in_device,
            wei,
            out,
            nrepeat);
    }
#endif

@@ -12,17 +12,17 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"

#define USE_DYNAMIC_MODE 1
#define USE_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 1
#define USE_CONV_FWD_V6R1_NCHW 1
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};

#if USE_DYNAMIC_MODE
#if USE_MODE
    // dynamic mode
    if(argc != 22)
    {
@@ -58,29 +58,29 @@ int main(int argc, char* argv[])
        exit(1);
    }

    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
    const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
    const bool do_verification = atoi(argv[3]);
    const int init_method = atoi(argv[4]);
    const bool do_log = atoi(argv[5]);
    const int nrepeat = atoi(argv[6]);
    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
    const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
    const bool do_verification = std::stoi(argv[3]);
    const int init_method = std::stoi(argv[4]);
    const bool do_log = std::stoi(argv[5]);
    const int nrepeat = std::stoi(argv[6]);

    const index_t N = atoi(argv[7]);
    const index_t K = atoi(argv[8]);
    const index_t C = atoi(argv[9]);
    const index_t Y = atoi(argv[10]);
    const index_t X = atoi(argv[11]);
    const index_t Hi = atoi(argv[12]);
    const index_t Wi = atoi(argv[13]);
    const index_t N = std::stoi(argv[7]);
    const index_t K = std::stoi(argv[8]);
    const index_t C = std::stoi(argv[9]);
    const index_t Y = std::stoi(argv[10]);
    const index_t X = std::stoi(argv[11]);
    const index_t Hi = std::stoi(argv[12]);
    const index_t Wi = std::stoi(argv[13]);

    const index_t conv_stride_h = atoi(argv[14]);
    const index_t conv_stride_w = atoi(argv[15]);
    const index_t conv_dilation_h = atoi(argv[16]);
    const index_t conv_dilation_w = atoi(argv[17]);
    const index_t in_left_pad_h = atoi(argv[18]);
    const index_t in_left_pad_w = atoi(argv[19]);
    const index_t in_right_pad_h = atoi(argv[20]);
    const index_t in_right_pad_w = atoi(argv[21]);
    const index_t conv_stride_h = std::stoi(argv[14]);
    const index_t conv_stride_w = std::stoi(argv[15]);
    const index_t conv_dilation_h = std::stoi(argv[16]);
    const index_t conv_dilation_w = std::stoi(argv[17]);
    const index_t in_left_pad_h = std::stoi(argv[18]);
    const index_t in_left_pad_w = std::stoi(argv[19]);
    const index_t in_right_pad_h = std::stoi(argv[20]);
    const index_t in_right_pad_w = std::stoi(argv[21]);

    const index_t YEff = (Y - 1) * conv_dilation_h + 1;
    const index_t XEff = (X - 1) * conv_dilation_w + 1;
@@ -95,12 +95,12 @@ int main(int argc, char* argv[])
        exit(1);
    }

    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
    const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
    const bool do_verification = atoi(argv[3]);
    const int init_method = atoi(argv[4]);
    const bool do_log = atoi(argv[5]);
    const int nrepeat = atoi(argv[6]);
    const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
    const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
    const bool do_verification = std::stoi(argv[3]);
    const int init_method = std::stoi(argv[4]);
    const bool do_log = std::stoi(argv[5]);
    const int nrepeat = std::stoi(argv[6]);

    constexpr index_t N = 128;
    constexpr index_t C = 192;
@@ -142,10 +142,8 @@ int main(int argc, char* argv[])

    std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);

    switch(layout)
    if(layout == ConvTensorLayout::NCHW)
    {
    case ConvTensorLayout::NCHW:
        // NCHW
        in_lengths_host[0] = static_cast<std::size_t>(N);
        in_lengths_host[1] = static_cast<std::size_t>(C);
        in_lengths_host[2] = static_cast<std::size_t>(Hi);
@@ -158,9 +156,9 @@ int main(int argc, char* argv[])
        out_lengths_host[1] = static_cast<std::size_t>(K);
        out_lengths_host[2] = static_cast<std::size_t>(Ho);
        out_lengths_host[3] = static_cast<std::size_t>(Wo);
        break;
    case ConvTensorLayout::NHWC:
        // NHWC
    }
    else if(layout == ConvTensorLayout::NHWC)
    {
        in_lengths_host[0] = static_cast<std::size_t>(N);
        in_lengths_host[1] = static_cast<std::size_t>(Hi);
        in_lengths_host[2] = static_cast<std::size_t>(Wi);
@@ -173,8 +171,10 @@ int main(int argc, char* argv[])
        out_lengths_host[1] = static_cast<std::size_t>(Ho);
        out_lengths_host[2] = static_cast<std::size_t>(Wo);
        out_lengths_host[3] = static_cast<std::size_t>(K);
        break;
    default: throw std::runtime_error("wrong! not implemented");
    }
    else
    {
        throw std::runtime_error("wrong! not implemented");
    }

    Tensor<in_data_t> in(in_lengths_host);
@@ -228,7 +228,7 @@ int main(int argc, char* argv[])
    }

    auto f_make_for_device_nchw = [&]() {
#if USE_DYNAMIC_MODE
#if USE_MODE
        const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
        const auto wei_lengths_dev = make_tuple(K, C, Y, X);
        const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
@@ -260,7 +260,7 @@ int main(int argc, char* argv[])
    };

    auto f_make_for_device_nhwc = [&]() {
#if USE_DYNAMIC_MODE
#if USE_MODE
        const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
        const auto wei_lengths_dev = make_tuple(K, Y, X, C);
        const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
@@ -301,20 +301,19 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                                   acc_data_t,
                                                                                   out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in,
            wei,
            out_device,
            nrepeat);
        device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                           acc_data_t,
                                                                           out_data_t>(tmp[I0],
                                                                                       tmp[I1],
                                                                                       tmp[I2],
                                                                                       tmp[I3],
                                                                                       tmp[I4],
                                                                                       tmp[I5],
                                                                                       tmp[I6],
                                                                                       in,
                                                                                       wei,
                                                                                       out_device,
                                                                                       nrepeat);
    }
#endif

@@ -328,20 +327,19 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nhwc();

        device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
                                                                                     acc_data_t,
                                                                                     out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in,
            wei,
            out_device,
            nrepeat);
        device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
                                                                             acc_data_t,
                                                                             out_data_t>(tmp[I0],
                                                                                         tmp[I1],
                                                                                         tmp[I2],
                                                                                         tmp[I3],
                                                                                         tmp[I4],
                                                                                         tmp[I5],
                                                                                         tmp[I6],
                                                                                         in,
                                                                                         wei,
                                                                                         out_device,
                                                                                         nrepeat);
    }
#endif

@@ -355,20 +353,19 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                                   acc_data_t,
                                                                                   out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in,
            wei,
            out_device,
            nrepeat);
        device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                           acc_data_t,
                                                                           out_data_t>(tmp[I0],
                                                                                       tmp[I1],
                                                                                       tmp[I2],
                                                                                       tmp[I3],
                                                                                       tmp[I4],
                                                                                       tmp[I5],
                                                                                       tmp[I6],
                                                                                       in,
                                                                                       wei,
                                                                                       out_device,
                                                                                       nrepeat);
    }
#endif

@@ -382,21 +379,20 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                                   16,
                                                                                   acc_data_t,
                                                                                   out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in,
            wei,
            out_device,
            nrepeat);
        device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
                                                                           16,
                                                                           acc_data_t,
                                                                           out_data_t>(tmp[I0],
                                                                                       tmp[I1],
                                                                                       tmp[I2],
                                                                                       tmp[I3],
                                                                                       tmp[I4],
                                                                                       tmp[I5],
                                                                                       tmp[I6],
                                                                                       in,
                                                                                       wei,
                                                                                       out_device,
                                                                                       nrepeat);
    }
#endif

@@ -410,9 +406,9 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
                                                                                      acc_data_t,
                                                                                      out_data_t>(
        device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
                                                                              acc_data_t,
                                                                              out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
@@ -437,9 +433,9 @@ int main(int argc, char* argv[])

        const auto tmp = f_make_for_device_nhwc();

        device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
                                                                                      acc_data_t,
                                                                                      out_data_t>(
        device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
                                                                              acc_data_t,
                                                                              out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
@@ -467,7 +463,6 @@ int main(int argc, char* argv[])

        check_error(out_host, out_device);

#if 0
        if(do_log)
        {
            LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
@@ -475,6 +470,5 @@ int main(int argc, char* argv[])
            LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
            LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
        }
#endif
    }
}
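Both conversion drivers above replace atoi with std::stoi for every command-line argument. The difference matters for argument validation, as this standalone illustration (not part of the diff) shows:

    #include <cstdlib>
    #include <iostream>
    #include <stdexcept>
    #include <string>

    int main()
    {
        // atoi silently yields 0 for malformed input, so a typo in argv
        // looks like a legitimate zero-sized problem.
        std::cout << std::atoi("abc") << std::endl; // prints 0, no error

        // std::stoi throws instead, so a bad command line fails loudly.
        try
        {
            std::cout << std::stoi("abc") << std::endl;
        }
        catch(const std::invalid_argument& e)
        {
            std::cout << "rejected: " << e.what() << std::endl;
        }

        return 0;
    }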
@@ -1,22 +0,0 @@
include_directories(BEFORE
    include
    ${PROJECT_BINARY_DIR}/host/online_compile/include
    ${PROJECT_SOURCE_DIR}/host/online_compile/include
    ${PROJECT_SOURCE_DIR}/host/host_tensor/include
    ${PROJECT_SOURCE_DIR}/host/solver/include
    ${PROJECT_SOURCE_DIR}/composable_kernel/include
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
    ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
    ${PROJECT_SOURCE_DIR}/external/rocm/include
    ${PROJECT_SOURCE_DIR}/external/half/include
)

set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp)

add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE})

target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor)
target_link_libraries(conv_fwd_driver_online PRIVATE online_compile)
@@ -1,453 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <stdlib.h>
|
||||
#include <half.hpp>
|
||||
#include "config.hpp"
|
||||
#include "print.hpp"
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "host_tensor_generator.hpp"
|
||||
#include "conv_common.hpp"
|
||||
#include "host_conv.hpp"
|
||||
#include "device_tensor.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "hipCheck.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||
#define USE_CONV_FWD_V6R1_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
|
||||
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
|
||||
|
||||
enum ConvForwardAlgo
|
||||
{
|
||||
V4R4NCHW, // 0
|
||||
V6R1NCHW, // 1
|
||||
V4R4XDLNCHW, // 2
|
||||
V4R4XDLNHWC // 3
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace ck_driver;
|
||||
using size_t = std::size_t;
|
||||
|
||||
hipStream_t stream;
|
||||
online_compile::Handle* handle;
|
||||
|
||||
MY_HIP_CHECK(hipStreamCreate(&stream));
|
||||
|
||||
handle = new online_compile::Handle(stream);
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
if(argc != 22)
|
||||
{
|
||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
||||
const bool do_verification = atoi(argv[3]);
|
||||
const int init_method = atoi(argv[4]);
|
||||
const bool do_log = atoi(argv[5]);
|
||||
const int nrepeat = atoi(argv[6]);
|
||||
|
||||
const index_t N = atoi(argv[7]);
|
||||
const index_t K = atoi(argv[8]);
|
||||
const index_t C = atoi(argv[9]);
|
||||
const index_t Y = atoi(argv[10]);
|
||||
const index_t X = atoi(argv[11]);
|
||||
const index_t Hi = atoi(argv[12]);
|
||||
const index_t Wi = atoi(argv[13]);
|
||||
|
||||
const index_t conv_stride_h = atoi(argv[14]);
|
||||
const index_t conv_stride_w = atoi(argv[15]);
|
||||
const index_t conv_dilation_h = atoi(argv[16]);
|
||||
const index_t conv_dilation_w = atoi(argv[17]);
|
||||
const index_t in_left_pad_h = atoi(argv[18]);
|
||||
const index_t in_left_pad_w = atoi(argv[19]);
|
||||
const index_t in_right_pad_h = atoi(argv[20]);
|
||||
const index_t in_right_pad_w = atoi(argv[21]);
|
||||
|
||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||
|
||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
|
||||
#if 1
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 0
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif

    std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);

    switch(layout)
    {
    case ConvTensorLayout::NCHW:
        // NCHW
        in_lengths_host[0] = static_cast<std::size_t>(N);
        in_lengths_host[1] = static_cast<std::size_t>(C);
        in_lengths_host[2] = static_cast<std::size_t>(Hi);
        in_lengths_host[3] = static_cast<std::size_t>(Wi);

        wei_lengths_host[0] = static_cast<std::size_t>(K);
        wei_lengths_host[1] = static_cast<std::size_t>(C);
        wei_lengths_host[2] = static_cast<std::size_t>(Y);
        wei_lengths_host[3] = static_cast<std::size_t>(X);

        out_lengths_host[0] = static_cast<std::size_t>(N);
        out_lengths_host[1] = static_cast<std::size_t>(K);
        out_lengths_host[2] = static_cast<std::size_t>(Ho);
        out_lengths_host[3] = static_cast<std::size_t>(Wo);
        break;
    case ConvTensorLayout::NHWC:
        // NHWC
        in_lengths_host[0] = static_cast<std::size_t>(N);
        in_lengths_host[1] = static_cast<std::size_t>(Hi);
        in_lengths_host[2] = static_cast<std::size_t>(Wi);
        in_lengths_host[3] = static_cast<std::size_t>(C);

        wei_lengths_host[0] = static_cast<std::size_t>(K);
        wei_lengths_host[1] = static_cast<std::size_t>(Y);
        wei_lengths_host[2] = static_cast<std::size_t>(X);
        wei_lengths_host[3] = static_cast<std::size_t>(C);

        out_lengths_host[0] = static_cast<std::size_t>(N);
        out_lengths_host[1] = static_cast<std::size_t>(Ho);
        out_lengths_host[2] = static_cast<std::size_t>(Wo);
        out_lengths_host[3] = static_cast<std::size_t>(K);
        break;
    default: throw std::runtime_error("wrong! not implemented");
    }
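
    // Note the two supported layout pairings: NCHW input goes with KCYX weights and
    // NKHW output, while NHWC input goes with KYXC weights and NHWK output.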

    Tensor<in_data_t> in(in_lengths_host);
    Tensor<in_data_t> wei(wei_lengths_host);
    Tensor<out_data_t> out_host(out_lengths_host);
    Tensor<out_data_t> out_device(out_lengths_host);

    std::cout << "layout: " << layout << std::endl;
    ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
    ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
    ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
    print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
    print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
    print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
    print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));

    std::size_t num_thread = std::thread::hardware_concurrency();

    switch(init_method)
    {
    case 0:
        // no initialization
        break;
    case 1:
        in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        break;
    case 2:
        in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        break;
    case 3:
        in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        break;
    case 4:
        in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        break;
    case 5:
        in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
        wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
        break;
    default:
        in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);

        auto gen_wei = [](auto... is) {
            return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
        };
        wei.GenerateTensorValue(gen_wei, num_thread);
    }

    auto f_make_for_device_nchw = [&]() {
        const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
        const auto wei_lengths_dev = make_tuple(K, C, Y, X);
        const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);

        return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev);
    };

    auto f_make_for_device_nhwc = [&]() {
        const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
        const auto wei_lengths_dev = make_tuple(K, Y, X, C);
        const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);

        return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev);
    };

    const auto conv_strides = make_tuple(conv_stride_h, conv_stride_w);
    const auto conv_dilations = make_tuple(conv_dilation_h, conv_dilation_w);
    const auto in_left_pads = make_tuple(in_left_pad_h, in_left_pad_w);
    const auto in_right_pads = make_tuple(in_right_pad_h, in_right_pad_w);
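
    // Each algo block below follows the same pattern: validate the layout, build the
    // device-side length tuples, pick a tunable/compile-parameter set, and hand
    // everything to the corresponding online-compiled kernel driver.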

#if USE_CONV_FWD_V4R4_NCHW
    if(algo == ConvForwardAlgo::V4R4NCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

        tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable =
            &default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw;

        online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<
            in_data_t,
            acc_data_t,
            out_data_t>(handle,
                        tmp[I0],
                        tmp[I1],
                        tmp[I2],
                        conv_strides,
                        conv_dilations,
                        in_left_pads,
                        in_right_pads,
                        in,
                        wei,
                        out_device,
                        tunable,
                        nrepeat);
    }
#endif

#if USE_CONV_FWD_V6R1_NCHW
    if(algo == ConvForwardAlgo::V6R1NCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

#if 1
        const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
            get_datatype_enum_from_type<in_data_t>::value,
            get_datatype_enum_from_type<acc_data_t>::value,
            get_datatype_enum_from_type<out_data_t>::value,
            256,
            4,
            1,
            128,
            32,
            8,
            4,
            4,
            1,
            {8, 2},
            {8, 2},
            {4, 1, 1, 1, 1},
            {2, 1, 1, 128, 1},
            {4, 1, 1, 1, 1},
            {1, 1, 1, 1, 1},
            {1, 4, 1, 1, 1},
            {8, 1, 1, 32, 1},
            {1, 1, 1, 1, 1},
            {1, 1, 1, 1, 1},
            4,
            true,
            true};
#elif 0
        const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
            get_datatype_enum_from_type<in_data_t>::value,
            get_datatype_enum_from_type<acc_data_t>::value,
            get_datatype_enum_from_type<out_data_t>::value,
            256,
            4,
            2,
            128,
            32,
            8,
            4,
            4,
            1,
            {8, 2},
            {8, 2},
            {4, 1, 1, 1, 2},
            {2, 1, 1, 128, 1},
            {4, 1, 1, 1, 1},
            {1, 1, 1, 1, 1},
            {1, 4, 1, 1, 2},
            {8, 1, 1, 32, 1},
            {1, 1, 1, 1, 1},
            {1, 1, 1, 1, 1},
            4,
            true,
            true};
#elif 1
        const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
            get_datatype_enum_from_type<in_data_t>::value,
            get_datatype_enum_from_type<acc_data_t>::value,
            get_datatype_enum_from_type<out_data_t>::value,
            256,
            4,
            4,
            128,
            32,
            8,
            4,
            4,
            1,
            {8, 2},
            {8, 2},
            {4, 1, 1, 1, 4},
            {2, 1, 1, 128, 1},
            {4, 1, 1, 1, 1},
            {1, 1, 1, 1, 1},
            {1, 4, 1, 1, 4},
            {8, 1, 1, 32, 1},
            {1, 1, 1, 1, 1},
            {1, 1, 1, 1, 1},
            4,
            true,
            true};
#endif

        online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<
            in_data_t,
            acc_data_t,
            out_data_t>(handle,
                        tmp[I0],
                        tmp[I1],
                        tmp[I2],
                        conv_strides,
                        conv_dilations,
                        in_left_pads,
                        in_right_pads,
                        in,
                        wei,
                        out_device,
                        compile_param,
                        nrepeat);
    }
#endif

#if USE_CONV_FWD_V4R4_XDLOPS_NCHW
    if(algo == ConvForwardAlgo::V4R4XDLNCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

        tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
            &default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;

        online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw<
            in_data_t,
            acc_data_t,
            out_data_t>(handle,
                        tmp[I0],
                        tmp[I1],
                        tmp[I2],
                        conv_strides,
                        conv_dilations,
                        in_left_pads,
                        in_right_pads,
                        in,
                        wei,
                        out_device,
                        tunable,
                        nrepeat);
    }
#endif

#if USE_CONV_FWD_V4R4_XDLOPS_NHWC
    if(algo == ConvForwardAlgo::V4R4XDLNHWC)
    {
        if(layout != ConvTensorLayout::NHWC)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nhwc();

        tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
            &default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;

        online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk<
            in_data_t,
            acc_data_t,
            out_data_t>(handle,
                        tmp[I0],
                        tmp[I1],
                        tmp[I2],
                        conv_strides,
                        conv_dilations,
                        in_left_pads,
                        in_right_pads,
                        in,
                        wei,
                        out_device,
                        tunable,
                        nrepeat);
    }
#endif

    if(do_verification)
    {
        host_direct_convolution(in,
                                wei,
                                out_host,
                                make_tuple(conv_stride_h, conv_stride_w),
                                make_tuple(conv_dilation_h, conv_dilation_w),
                                make_tuple(in_left_pad_h, in_left_pad_w),
                                make_tuple(in_right_pad_h, in_right_pad_w),
                                layout);

        check_error(out_host, out_device);

#if 0
        if(do_log)
        {
            LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
            LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
            LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
            LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
        }
#endif
    }

    delete handle;
    MY_HIP_CHECK(hipStreamDestroy(stream));
}
@@ -1,395 +0,0 @@
#pragma once
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp"

namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw {

template <typename TInWei, typename TAcc, typename TOut>
static std::string get_network_config_string_from_types()
{
    using namespace ck;

    std::string out;

    out += std::to_string(get_datatype_enum_from_type<TInWei>::value) + "_" +
           std::to_string(get_datatype_enum_from_type<TAcc>::value) + "_" +
           std::to_string(get_datatype_enum_from_type<TOut>::value);

    return (out);
};

static std::string
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt)
{
    std::string out("TUN_");

    out += std::to_string(pt->BlockSize) + "_";

    out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
           std::to_string(pt->KPerBlock) + "_";
    out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" +
           std::to_string(pt->KPerThread) + "_";
    out += std::to_string(pt->M1N1ThreadClusterM10) + "x" +
           std::to_string(pt->M1N1ThreadClusterN10) + "x" +
           std::to_string(pt->M1N1ThreadClusterM11) + "x" +
           std::to_string(pt->M1N1ThreadClusterN11) + "_";

    out += std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]) + "_";

    out += std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]) + "_";

    out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";

    out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";

    out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
    out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
    out += std::to_string(pt->ABlockTransferDstScalarPerVector_M1) + "_";
    out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";

    out += std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]) + "_";

    out += std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]) + "_";

    out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";

    out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";

    out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
    out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
    out += std::to_string(pt->BBlockTransferDstScalarPerVector_N1) + "_";
    out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";

    out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_";

    out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
    out += std::to_string(pt->CThreadTransferDstScalarPerVector);

    return (out);
};

template <typename TInWei, typename TAcc, typename TOut>
static std::string get_definition_string_from_types()
{
    using namespace ck;

    std::string out;

    out +=
        " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TInWei>::value) +
        " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TAcc>::value) +
        " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TOut>::value);

    return (out);
};

static std::string
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt)
{
    std::string out;

    out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);

    out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
           " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
           " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
    out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) +
           " -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) +
           " -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread);

    out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) +
           " -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) +
           " -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) +
           " -DCK_PARAM_M1N1ThreadClusterN11=" + std::to_string(pt->M1N1ThreadClusterN11);

    out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1=" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "," +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "," +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]);

    out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1=" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]);

    out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);

    out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
           std::to_string(pt->ABlockTransferSrcAccessOrder[2]);

    out +=
        " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
    out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
           std::to_string(pt->ABlockTransferSrcScalarPerVector);
    out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_M1=" +
           std::to_string(pt->ABlockTransferDstScalarPerVector_M1);
    out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
           std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);

    out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1=" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "," +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "," +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]);

    out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1=" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]);

    out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);

    out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
           std::to_string(pt->BBlockTransferSrcAccessOrder[2]);

    out +=
        " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
    out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
           std::to_string(pt->BBlockTransferSrcScalarPerVector);
    out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_N1=" +
           std::to_string(pt->BBlockTransferDstScalarPerVector_N1);
    out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
           std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);

    out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]);

    out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
           std::to_string(pt->CThreadTransferSrcDstVectorDim);
    out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
           std::to_string(pt->CThreadTransferDstScalarPerVector);

    return (out);
};

} // namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
    online_compile::Handle* handle,
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable,
    ck::index_t nrepeat)
{
    using namespace ck;
    using namespace ck_driver;
    using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
    using size_t = std::size_t;

    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // The following code is only used for computing grid_size, hasMainKBlockLoop and
    // hasDoubleTailKBlockLoop

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);

    const auto descs =
        transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                        in_n_c_hi_wi_desc,
                                                                        out_n_k_ho_wo_desc,
                                                                        conv_strides,
                                                                        conv_dilations,
                                                                        in_left_pads,
                                                                        in_right_pads);
    const auto a_k_m_grid_desc = descs[I0];
    const auto c_m_n_grid_desc = descs[I2];
    const auto M = c_m_n_grid_desc.GetLength(I0);
    const auto N = c_m_n_grid_desc.GetLength(I1);
    const auto K = a_k_m_grid_desc.GetLength(I0);
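
    // Here M, N and K are the GEMM dimensions derived from the convolution; in this
    // kernel family M = output channels, N = N * Ho * Wo, and K = C * Y * X (the
    // xdlops variant below computes the same quantities explicitly).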

    const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
    const bool hasMainKBlockLoop = ((K + tunable->KPerBlock) / (2 * tunable->KPerBlock) > 1);
    const bool hasDoubleTailKBlockLoop = ((K / tunable->KPerBlock) % 2 == 0);
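    // grid_size assumes M and N divide evenly by the per-block tile sizes. The two
    // flags select a specialization of the double-buffered K loop: whether a main
    // loop over 2 * KPerBlock chunks exists at all, and whether an even chunk count
    // leaves a two-stage tail to drain.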
    /////////////////////////////////////////////////////////////////////////////////////////////////////////////

    // these buffers are usually provided by the user application
    DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());

    // these are workspace buffers that should be exposed to the user by the
    // corresponding workspace API
    DeviceMem workspace_buf(4096);

    void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
    void* b_k_n0_n1_grid_desc_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
    void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
    void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
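
    // The 4096-byte workspace is carved into four 1024-byte slots; the "_prepare"
    // kernel below writes the device-side copies of the transformed tensor
    // descriptors into them, and the main kernel reads them back.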

    const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
    const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
    const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};

    std::string program_name =
        "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp";
    std::string algo_name = "implicit_gemm_conv_fwd_v4r4_dlops_nchw";

    std::string param = " -std=c++17 ";
    std::string network_config;

    param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " " +
             get_definition_string_from_tunable(tunable) +
             " -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) +
             " -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop);
    network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
                     get_network_config_string_from_tunable(tunable) + "_" +
                     std::to_string(hasMainKBlockLoop) + "_" +
                     std::to_string(hasDoubleTailKBlockLoop);

    std::vector<float> kernel1_times;
    std::vector<float> kernel2_times;

    for(index_t i = 0; i < nrepeat; ++i)
    {
        KernelTimer timer1, timer2;
        std::string kernel_name;

        kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare";
        auto network_config_1 = network_config + "_1";

        timer1.Start();
        handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
            static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
            static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
            static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
            static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
            static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
            static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
            static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
            conv_strides[I0],
            conv_strides[I1],
            conv_dilations[I0],
            conv_dilations[I1],
            in_left_pads[I0],
            in_left_pads[I1],
            in_right_pads[I0],
            in_right_pads[I1],
            a_k_m0_m1_grid_desc_dev_buf,
            b_k_n0_n1_grid_desc_dev_buf,
            c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf,
            c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
        timer1.End();

        kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw";
        auto network_config_2 = network_config + "_2";

        timer2.Start();
        handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
            reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
            (const void*)(a_k_m0_m1_grid_desc_dev_buf),
            (const void*)(b_k_n0_n1_grid_desc_dev_buf),
            (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf),
            (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
        timer2.End();

        kernel1_times.push_back(timer1.GetElapsedTime());
        kernel2_times.push_back(timer2.GetElapsedTime());
    }

    {
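        // Note: the averages below skip the first timing sample (std::next), which
        // plausibly absorbs the one-time cost of online kernel compilation.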
        auto ave_time1 =
            std::accumulate(
                std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
            (nrepeat - 1);
        auto ave_time2 =
            std::accumulate(
                std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
            (nrepeat - 1);

        const auto N = in_n_c_hi_wi_lengths[I0];
        const auto C = in_n_c_hi_wi_lengths[I1];

        const auto K = out_n_k_ho_wo_lengths[I1];
        const auto Ho = out_n_k_ho_wo_lengths[I2];
        const auto Wo = out_n_k_ho_wo_lengths[I3];

        const auto Y = wei_k_c_y_x_lengths[I2];
        const auto X = wei_k_c_y_x_lengths[I3];

        float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                     (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
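        // Flop count: each of the N * K * Ho * Wo outputs accumulates C * Y * X
        // multiply-adds, i.e. 2 * N * K * Ho * Wo * C * Y * X flops in total;
        // Gflop divided by milliseconds yields Tflop/s.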

        std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
                  << ave_time2 << "), " << perf << " TFlop/s" << std::endl;
    };

    // copy result back to host
    out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
@@ -1,386 +0,0 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp"

namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw {

template <typename TInWei, typename TAcc, typename TOut>
static std::string get_network_config_string_from_types()
{
    using namespace ck;

    std::string out;

    out += std::to_string(get_datatype_enum_from_type<TInWei>::value) + "_" +
           std::to_string(get_datatype_enum_from_type<TAcc>::value) + "_" +
           std::to_string(get_datatype_enum_from_type<TOut>::value);

    return (out);
};

static std::string
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
{
    std::string out("TUN_");

    out += std::to_string(pt->BlockSize) + "_";

    out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
           std::to_string(pt->KPerBlock) + "_";
    out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
           std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
           std::to_string(pt->K1) + "_";

    out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";

    out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";

    out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";

    out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";

    out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
    out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
    out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
    out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";

    out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";

    out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";

    out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";

    out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";

    out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
    out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
    out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
    out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";

    out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";

    out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
    out += std::to_string(pt->CThreadTransferDstScalarPerVector);

    return (out);
};

template <typename TInWei, typename TAcc, typename TOut>
static std::string get_definition_string_from_types()
{
    using namespace ck;

    std::string out;

    out +=
        " -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TInWei>::value) +
        " -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TAcc>::value) +
        " -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TOut>::value);

    return (out);
};

static std::string
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
{
    std::string out;

    out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);

    out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
           " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
           " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
    out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
           " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
           " -DCK_PARAM_K1=" + std::to_string(pt->K1) +
           " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
           " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);

    out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);

    out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);

    out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);

    out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
           std::to_string(pt->ABlockTransferSrcAccessOrder[2]);

    out +=
        " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
    out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
           std::to_string(pt->ABlockTransferSrcScalarPerVector);
    out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
           std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
    out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
           std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);

    out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);

    out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);

    out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);

    out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
           std::to_string(pt->BBlockTransferSrcAccessOrder[2]);

    out +=
        " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
    out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
           std::to_string(pt->BBlockTransferSrcScalarPerVector);
    out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
           std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
    out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
           std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);

    out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);

    out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
           std::to_string(pt->CThreadTransferSrcDstVectorDim);
    out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
           std::to_string(pt->CThreadTransferDstScalarPerVector);

    return (out);
};

} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
    online_compile::Handle* handle,
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable,
    ck::index_t nrepeat)
{
    using namespace ck;
    using namespace ck_driver;
    using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
    using size_t = std::size_t;

    /////////////////////////////////////////////////////////////////////////////////////////////////////////////

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);

    const auto n = in_n_c_hi_wi_desc.GetLength(I0);
    const auto c = in_n_c_hi_wi_desc.GetLength(I1);
    const auto hi = in_n_c_hi_wi_desc.GetLength(I2);
    const auto wi = in_n_c_hi_wi_desc.GetLength(I3);
    const auto k = wei_k_c_y_x_desc.GetLength(I0);
    const auto y = wei_k_c_y_x_desc.GetLength(I2);
    const auto x = wei_k_c_y_x_desc.GetLength(I3);
    const auto ho = out_n_k_ho_wo_desc.GetLength(I2);
    const auto wo = out_n_k_ho_wo_desc.GetLength(I3);

    const auto M = k;
    const auto N = n * ho * wo;
    const auto K = c * y * x;
    const auto K0 = K / tunable->K1;
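
    // The GEMM K dimension is consumed as K0 * K1; K1 is the tunable innermost
    // (vectorized) length, so this exact split assumes K is divisible by K1.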

    const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
    /////////////////////////////////////////////////////////////////////////////////////////////////////////////

    // these buffers are usually provided by the user application
    DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());

    // these are workspace buffers that should be exposed to the user by the
    // corresponding workspace API
    DeviceMem workspace_buf(4096);

    void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
    void* b_k_n0_n1_grid_desc_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
    void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
    void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);

    const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
    const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
    const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};

    std::string program_name =
        "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp";
    std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nchw";

    std::string param = " -std=c++17 ";
    std::string network_config;

    param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " " + " -DCK_USE_AMD_XDLOPS" +
             get_definition_string_from_tunable(tunable);

    network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
                     get_network_config_string_from_tunable(tunable);

    std::vector<float> kernel1_times;
    std::vector<float> kernel2_times;

    for(index_t i = 0; i < nrepeat; ++i)
    {
        KernelTimer timer1, timer2;
        std::string kernel_name;

        kernel_name =
            "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare";
        auto network_config_1 = network_config + "_1";

        timer1.Start();
        handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
            static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
            static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
            static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
            static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
            static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
            static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
            static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
            conv_strides[I0],
            conv_strides[I1],
            conv_dilations[I0],
            conv_dilations[I1],
            in_left_pads[I0],
            in_left_pads[I1],
            in_right_pads[I0],
            in_right_pads[I1],
            a_k_m0_m1_grid_desc_dev_buf,
            b_k_n0_n1_grid_desc_dev_buf,
            c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf,
            c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
        timer1.End();

        kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw";
        auto network_config_2 = network_config + "_2";

        timer2.Start();
        handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
            reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
            (const void*)(a_k_m0_m1_grid_desc_dev_buf),
            (const void*)(b_k_n0_n1_grid_desc_dev_buf),
            (const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf),
            (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
        timer2.End();

        kernel1_times.push_back(timer1.GetElapsedTime());
        kernel2_times.push_back(timer2.GetElapsedTime());
    }

    {
        auto ave_time1 =
            std::accumulate(
                std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
            (nrepeat - 1);
        auto ave_time2 =
            std::accumulate(
                std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
            (nrepeat - 1);

        const auto N = in_n_c_hi_wi_lengths[I0];
        const auto C = in_n_c_hi_wi_lengths[I1];

        const auto K = out_n_k_ho_wo_lengths[I1];
        const auto Ho = out_n_k_ho_wo_lengths[I2];
        const auto Wo = out_n_k_ho_wo_lengths[I3];

        const auto Y = wei_k_c_y_x_lengths[I2];
        const auto X = wei_k_c_y_x_lengths[I3];

        float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                     (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);

        std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
                  << ave_time2 << "), " << perf << " TFlop/s" << std::endl;
    };

    // copy result back to host
    out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
@@ -1,389 +0,0 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "handle.hpp"
|
||||
#include "online_driver_common.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk {
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_network_config_string_from_types()
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::string out;
|
||||
|
||||
out += std::to_string(get_datatype_enum_from_type<TInWei>::value) + "_" +
|
||||
std::to_string(get_datatype_enum_from_type<TAcc>::value) + "_" +
|
||||
std::to_string(get_datatype_enum_from_type<TOut>::value);
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
static std::string
|
||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
|
||||
{
|
||||
std::string out("TUN_");
|
||||
|
||||
out += std::to_string(pt->BlockSize) + "_";
|
||||
|
||||
out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
|
||||
std::to_string(pt->KPerBlock) + "_";
|
||||
out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
|
||||
std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
|
||||
std::to_string(pt->K1) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
|
||||
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
|
||||
out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
|
||||
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
|
||||
|
||||
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
|
||||
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
|
||||
out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
|
||||
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
|
||||
|
||||
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
|
||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";
|
||||
|
||||
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
|
||||
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
|
||||
|
||||
return (out);
|
||||
};
|
||||
|
||||
template <typename TInWei, typename TAcc, typename TOut>
|
||||
static std::string get_definition_string_from_types()
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::string out;
|
||||
|
||||
out +=
|
||||
" -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TInWei>::value) +
|
||||
" -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TAcc>::value) +
|
||||
" -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TOut>::value);
|
||||
|
||||
return (out);
|
||||
};

static std::string
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
{
    std::string out;

    out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);

    out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
           " -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
           " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
    out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
           " -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
           " -DCK_PARAM_K1=" + std::to_string(pt->K1) +
           " -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
           " -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);

    out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
           std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);

    out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);

    out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
           std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);

    out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
           std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
           std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
           std::to_string(pt->ABlockTransferSrcAccessOrder[2]);

    out +=
        " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
    out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
           std::to_string(pt->ABlockTransferSrcScalarPerVector);
    out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
           std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
    out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
           std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);

    out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
           std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);

    out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);

    out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
           std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);

    out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
           std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
           std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
           std::to_string(pt->BBlockTransferSrcAccessOrder[2]);

    out +=
        " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
    out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
           std::to_string(pt->BBlockTransferSrcScalarPerVector);
    out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
           std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
    out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
           std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);

    out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
           std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);

    out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
           std::to_string(pt->CThreadTransferSrcDstVectorDim);
    out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
           std::to_string(pt->CThreadTransferDstScalarPerVector);

    return (out);
}
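
// Hedged usage sketch (not part of the original source; the tunable values
// below are assumptions). It shows how the two flag builders above are
// typically combined into the single option string handed to the online
// compiler:
//
//     tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk t{}; // hypothetical values
//     t.BlockSize = 256;
//     t.MPerBlock = 128;
//     t.NPerBlock = 128;
//     std::string flags = " -std=c++17 " +
//                         get_definition_string_from_types<float, float, float>() +
//                         get_definition_string_from_tunable(&t);
//     // flags: " -std=c++17  -DCK_PARAM_ABDataTypeEnum=... -DCK_PARAM_BlockSize=256 ..."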

} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
    online_compile::Handle* handle,
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    Tensor<TOut>& out_n_ho_wo_k,
    const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable,
    ck::index_t nrepeat)
{
    using namespace ck;
    using namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
    using size_t = std::size_t;

    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // The following code is only used to compute grid_size, hasMainKBlockLoop and
    // hasDoubleTailKBlockLoop

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto in_n_hi_wi_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);

    const auto n  = in_n_hi_wi_c_desc.GetLength(I0);
    const auto hi = in_n_hi_wi_c_desc.GetLength(I1);
    const auto wi = in_n_hi_wi_c_desc.GetLength(I2);
    const auto c  = in_n_hi_wi_c_desc.GetLength(I3);

    const auto k = wei_k_y_x_c_desc.GetLength(I0);
    const auto y = wei_k_y_x_c_desc.GetLength(I1);
    const auto x = wei_k_y_x_c_desc.GetLength(I2);

    const auto ho = out_n_ho_wo_k_desc.GetLength(I1);
    const auto wo = out_n_ho_wo_k_desc.GetLength(I2);

    const auto M  = k;
    const auto N  = n * ho * wo;
    const auto K  = c * y * x;
    const auto K0 = K / tunable->K1;

    const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);

    // these buffers are usually provided by the user application
    DeviceMem in_n_hi_wi_c_dev_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_dev_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_dev_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_dev_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_dev_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_dev_buf.ToDevice(out_n_ho_wo_k.mData.data());

    // these are workspace buffers that should be exposed to the user through the
    // corresponding workspace API
    DeviceMem workspace_buf(4096);

    void* a_k0_m_k1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
    void* b_k0_n_k1_grid_desc_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
    void* c_m0_m1_m2_n_grid_desc_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
    void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
        static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);

    const std::vector<size_t> vld  = {static_cast<size_t>(tunable->BlockSize), 1, 1};
    const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
    const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};

    std::string program_name =
        "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp";
    std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nhwc";

    std::string param = " -std=c++17 ";
    std::string network_config;

    param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " -DCK_USE_AMD_XDLOPS ";
    param += get_definition_string_from_tunable(tunable);

    network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
                     get_network_config_string_from_tunable(tunable);

    std::vector<float> kernel1_times;
    std::vector<float> kernel2_times;

    for(index_t i = 0; i < nrepeat; ++i)
    {
        KernelTimer timer1, timer2;
        std::string kernel_name;

        kernel_name =
            "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare";
        auto network_config_1 = network_config + "_1";

        timer1.Start();
        handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
            static_cast<index_t>(in_n_hi_wi_c_lengths[I0]),
            static_cast<index_t>(in_n_hi_wi_c_lengths[I1]),
            static_cast<index_t>(in_n_hi_wi_c_lengths[I2]),
            static_cast<index_t>(in_n_hi_wi_c_lengths[I3]),
            static_cast<index_t>(wei_k_y_x_c_lengths[I0]),
            static_cast<index_t>(wei_k_y_x_c_lengths[I1]),
            static_cast<index_t>(wei_k_y_x_c_lengths[I2]),
            conv_strides[I0],
            conv_strides[I1],
            conv_dilations[I0],
            conv_dilations[I1],
            in_left_pads[I0],
            in_left_pads[I1],
            in_right_pads[I0],
            in_right_pads[I1],
            a_k0_m_k1_grid_desc_dev_buf,
            b_k0_n_k1_grid_desc_dev_buf,
            c_m0_m1_m2_n_grid_desc_dev_buf,
            c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
        timer1.End();

        kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk";
        auto network_config_2 = network_config + "_2";

        timer2.Start();
        handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
            reinterpret_cast<const TInWei*>(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const TInWei*>(wei_k_y_x_c_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<TOut*>(out_n_ho_wo_k_dev_buf.GetDeviceBuffer()),
            (const void*)(a_k0_m_k1_grid_desc_dev_buf),
            (const void*)(b_k0_n_k1_grid_desc_dev_buf),
            (const void*)(c_m0_m1_m2_n_grid_desc_dev_buf),
            (const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
        timer2.End();

        kernel1_times.push_back(timer1.GetElapsedTime());
        kernel2_times.push_back(timer2.GetElapsedTime());
    }

    {
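        // std::next() drops the first (warm-up) sample, so nrepeat iterations
        // leave nrepeat - 1 measurements and the averages divide accordingly.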
        auto ave_time1 =
            std::accumulate(
                std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
            (nrepeat - 1);
        auto ave_time2 =
            std::accumulate(
                std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
            (nrepeat - 1);

        const auto N = in_n_hi_wi_c_lengths[I0];
        const auto C = in_n_hi_wi_c_lengths[I3];

        const auto Ho = out_n_ho_wo_k_lengths[I1];
        const auto Wo = out_n_ho_wo_k_lengths[I2];
        const auto K  = out_n_ho_wo_k_lengths[I3];

        const auto Y = wei_k_y_x_c_lengths[I1];
        const auto X = wei_k_y_x_c_lengths[I2];

        float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time2;

        std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
                  << ave_time2 << "), " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_ho_wo_k_dev_buf.FromDevice(out_n_ho_wo_k.mData.data());
}
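
// Hedged call sketch (not part of the original source; handle creation, the
// tensor objects and the tunable values are assumptions):
//
//     tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk t{/* tuned values */};
//     online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
//         handle,
//         in_lengths, wei_lengths, out_lengths,
//         conv_strides, conv_dilations, in_left_pads, in_right_pads,
//         in, wei, out, &t, /*nrepeat=*/10);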
@@ -1,182 +0,0 @@
#pragma once
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "convolution_problem_descriptor.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
    online_compile::Handle* handle,
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param,
    ck::index_t nrepeat)
{
    using namespace ck;
    using namespace ck_driver;
    using size_t = std::size_t;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    ConvolutionProblemDescriptor conv_problem_desc{in_n_c_hi_wi_lengths[I0],
                                                   out_n_k_ho_wo_lengths[I1],
                                                   in_n_c_hi_wi_lengths[I1],
                                                   wei_k_c_y_x_lengths[I2],
                                                   wei_k_c_y_x_lengths[I3],
                                                   in_n_c_hi_wi_lengths[I2],
                                                   in_n_c_hi_wi_lengths[I3],
                                                   out_n_k_ho_wo_lengths[I2],
                                                   out_n_k_ho_wo_lengths[I3],
                                                   conv_strides[I0],
                                                   conv_strides[I1],
                                                   conv_dilations[I0],
                                                   conv_dilations[I1],
                                                   in_left_pads[I0],
                                                   in_left_pads[I1],
                                                   in_right_pads[I0],
                                                   in_right_pads[I1],
                                                   get_datatype_enum_from_type<TInWei>::value,
                                                   get_datatype_enum_from_type<TInWei>::value,
                                                   get_datatype_enum_from_type<TOut>::value};

    if(!ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::IsValidCompileParameter(conv_problem_desc,
                                                                   compile_param))
    {
        throw std::runtime_error("wrong! IsValidCompileParameter fail");
    }

    DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());

    // the workspace is used to save the transformed tensor descriptors created by
    // the prepare kernel
    DeviceMem workspace_dev_buf(
        ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetWorkSpaceSize(conv_problem_desc, compile_param));

    const auto block_size = std::size_t(
        ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetBlockSize(conv_problem_desc, compile_param));

    const auto grid_size = std::size_t(
        ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetGridSize(conv_problem_desc, compile_param));

    const std::vector<size_t> vld1 = {1, 1, 1};
    const std::vector<size_t> vgd1 = {1, 1, 1};

    const std::vector<size_t> vld2 = {static_cast<size_t>(block_size), 1, 1};
    const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * block_size), 1, 1};

    std::string program_name =
        "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp";
    std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw";

    std::string compile_param_string =
        get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString();
    std::string network_config = compile_param_string;

    std::vector<float> kernel1_times;
    std::vector<float> kernel2_times;

    for(index_t i = 0; i < nrepeat + 1; ++i)
    {
        KernelTimer timer1, timer2;
        std::string kernel_name;

        kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare";
        auto network_config_1 = network_config + "_1";

        timer1.Start();
        handle->AddKernel(algo_name,
                          network_config_1,
                          program_name,
                          kernel_name,
                          vld1,
                          vgd1,
                          compile_param_string)(static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
                                                static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
                                                static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
                                                static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
                                                static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
                                                static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
                                                static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
                                                conv_strides[I0],
                                                conv_strides[I1],
                                                conv_dilations[I0],
                                                conv_dilations[I1],
                                                in_left_pads[I0],
                                                in_left_pads[I1],
                                                in_right_pads[I0],
                                                in_right_pads[I1],
                                                (void*)(workspace_dev_buf.GetDeviceBuffer()));
        timer1.End();

        kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw";
        auto network_config_2 = network_config + "_2";

        timer2.Start();
        handle->AddKernel(algo_name,
                          network_config_2,
                          program_name,
                          kernel_name,
                          vld2,
                          vgd2,
                          compile_param_string)(
            reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
            reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
            (const void*)(workspace_dev_buf.GetDeviceBuffer()));
        timer2.End();

        kernel1_times.push_back(timer1.GetElapsedTime());
        kernel2_times.push_back(timer2.GetElapsedTime());
    }

    {
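        // The loop above ran nrepeat + 1 times; std::next() drops the warm-up
        // sample, leaving exactly nrepeat measurements, hence the divisor.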
        auto ave_time1 =
            std::accumulate(
                std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
            nrepeat;
        auto ave_time2 =
            std::accumulate(
                std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
            nrepeat;

        float perf = (float)(conv_problem_desc.CalculateFlop()) /
                     (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);

        std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
                  << ave_time2 << "), " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
@@ -10,6 +10,8 @@ set(HOST_TENSOR_SOURCE
## the library target
add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE})

target_include_directories(host_tensor SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)

target_link_libraries(host_tensor PRIVATE hip::device)
target_link_libraries(host_tensor INTERFACE hip::host)

@@ -1,7 +1,7 @@
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP

#include "dynamic_tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"

enum ConvTensorLayout
{
@@ -19,8 +19,8 @@ template <typename... InDesc,
          typename LeftPads,
          typename RightPads>
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
    const ck::DynamicTensorDescriptor<InDesc...>& in_desc,
    const ck::DynamicTensorDescriptor<WeiDesc...>& wei_desc,
    const ck::TensorDescriptor<InDesc...>& in_desc,
    const ck::TensorDescriptor<WeiDesc...>& wei_desc,
    const ConvStrides& conv_strides,
    const ConvDilations conv_dilations,
    const LeftPads& left_pads,
@@ -57,12 +57,12 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
    const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
    const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;

    return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
    return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
}
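
// Worked example (hypothetical sizes): Hi = 14, LeftPadH = RightPadH = 1, Y = 3,
// dilation 1 gives YEff = (3 - 1) * 1 + 1 = 3 and
// Ho = (14 + 1 + 1 - 3) / 1 + 1 = 14, i.e. a "same" convolution.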

template <class InDesc, class WeiDesc, class OutDesc>
constexpr std::size_t
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc)
{
    using namespace ck;

@@ -34,24 +34,16 @@ struct KernelTimer
using device_stream_t = hipStream_t;

template <typename... Args, typename F>
void launch_kernel(F kernel,
                   dim3 grid_dim,
                   dim3 block_dim,
                   std::size_t lds_byte,
                   hipStream_t stream_id,
                   Args... args)
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
    hipStream_t stream_id = nullptr;

    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
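
// Hedged usage sketch (the kernel and argument names are assumptions): after
// this change callers no longer pass a stream; the null (default) stream is
// used internally.
//
//     launch_kernel(my_kernel, dim3(grid_size), dim3(block_size),
//                   /*lds_byte=*/0, arg0, arg1);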

template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
                             int nrepeat,
                             dim3 grid_dim,
                             dim3 block_dim,
                             std::size_t lds_byte,
                             hipStream_t stream_id,
                             Args... args)
float launch_and_time_kernel(
    F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
    KernelTimer timer;

@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel,

    printf("Warm up\n");

    hipStream_t stream_id = nullptr;

    // warm up
    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);

Some files were not shown because too many files have changed in this diff.