mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 08:18:26 +00:00
Merge pull request #8 from ROCmSoftwarePlatform/miopen_downstream_init_integration
[ROCm/composable_kernel commit: ccc4a1d365]
This commit is contained in:
3
.clang-tidy
Normal file
3
.clang-tidy
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
CheckOptions:
|
||||||
|
- key: bugprone-reserved-identifier.AllowedIdentifiers
|
||||||
|
value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__'
|
||||||
164
CMakeLists.txt
164
CMakeLists.txt
@@ -1,10 +1,9 @@
|
|||||||
cmake_minimum_required(VERSION 2.8.3)
|
cmake_minimum_required(VERSION 3.5)
|
||||||
project(modular_convolution)
|
project(composable_kernel)
|
||||||
|
|
||||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||||
|
|
||||||
include(TargetFlags)
|
include(CheckCXXCompilerFlag)
|
||||||
include(AddKernels)
|
|
||||||
|
|
||||||
## C++
|
## C++
|
||||||
enable_language(CXX)
|
enable_language(CXX)
|
||||||
@@ -39,4 +38,161 @@ link_libraries(${OpenMP_pthread_LIBRARY})
|
|||||||
find_package(HIP REQUIRED)
|
find_package(HIP REQUIRED)
|
||||||
message(STATUS "Build with HIP ${hip_VERSION}")
|
message(STATUS "Build with HIP ${hip_VERSION}")
|
||||||
|
|
||||||
|
## half
|
||||||
|
#find_path(HALF_INCLUDE_DIR half.hpp)
|
||||||
|
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
|
||||||
|
|
||||||
|
# CMAKE_CXX_FLAGS
|
||||||
|
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||||
|
if(BUILD_DEV)
|
||||||
|
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
|
||||||
|
endif()
|
||||||
|
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||||
|
|
||||||
|
## tidy
|
||||||
|
include(EnableCompilerWarnings)
|
||||||
|
set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
|
||||||
|
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
|
||||||
|
set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
|
||||||
|
# Enable tidy on hip
|
||||||
|
elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
|
||||||
|
set(MIOPEN_TIDY_ERRORS ALL)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
include(ClangTidy)
|
||||||
|
enable_clang_tidy(
|
||||||
|
CHECKS
|
||||||
|
*
|
||||||
|
-abseil-*
|
||||||
|
-android-cloexec-fopen
|
||||||
|
# Yea we shouldn't be using rand()
|
||||||
|
-cert-msc30-c
|
||||||
|
-bugprone-exception-escape
|
||||||
|
-bugprone-macro-parentheses
|
||||||
|
-cert-env33-c
|
||||||
|
-cert-msc32-c
|
||||||
|
-cert-msc50-cpp
|
||||||
|
-cert-msc51-cpp
|
||||||
|
-cert-dcl37-c
|
||||||
|
-cert-dcl51-cpp
|
||||||
|
-clang-analyzer-alpha.core.CastToStruct
|
||||||
|
-clang-analyzer-optin.performance.Padding
|
||||||
|
-clang-diagnostic-deprecated-declarations
|
||||||
|
-clang-diagnostic-extern-c-compat
|
||||||
|
-clang-diagnostic-unused-command-line-argument
|
||||||
|
-cppcoreguidelines-avoid-c-arrays
|
||||||
|
-cppcoreguidelines-avoid-magic-numbers
|
||||||
|
-cppcoreguidelines-explicit-virtual-functions
|
||||||
|
-cppcoreguidelines-init-variables
|
||||||
|
-cppcoreguidelines-macro-usage
|
||||||
|
-cppcoreguidelines-non-private-member-variables-in-classes
|
||||||
|
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
|
||||||
|
-cppcoreguidelines-pro-bounds-constant-array-index
|
||||||
|
-cppcoreguidelines-pro-bounds-pointer-arithmetic
|
||||||
|
-cppcoreguidelines-pro-type-member-init
|
||||||
|
-cppcoreguidelines-pro-type-reinterpret-cast
|
||||||
|
-cppcoreguidelines-pro-type-union-access
|
||||||
|
-cppcoreguidelines-pro-type-vararg
|
||||||
|
-cppcoreguidelines-special-member-functions
|
||||||
|
-fuchsia-*
|
||||||
|
-google-explicit-constructor
|
||||||
|
-google-readability-braces-around-statements
|
||||||
|
-google-readability-todo
|
||||||
|
-google-runtime-int
|
||||||
|
-google-runtime-references
|
||||||
|
-hicpp-vararg
|
||||||
|
-hicpp-braces-around-statements
|
||||||
|
-hicpp-explicit-conversions
|
||||||
|
-hicpp-named-parameter
|
||||||
|
-hicpp-no-array-decay
|
||||||
|
# We really shouldn't use bitwise operators with signed integers, but
|
||||||
|
# opencl leaves us no choice
|
||||||
|
-hicpp-avoid-c-arrays
|
||||||
|
-hicpp-signed-bitwise
|
||||||
|
-hicpp-special-member-functions
|
||||||
|
-hicpp-uppercase-literal-suffix
|
||||||
|
-hicpp-use-auto
|
||||||
|
-hicpp-use-equals-default
|
||||||
|
-hicpp-use-override
|
||||||
|
-llvm-header-guard
|
||||||
|
-llvm-include-order
|
||||||
|
#-llvmlibc-*
|
||||||
|
-llvmlibc-restrict-system-libc-headers
|
||||||
|
-llvmlibc-callee-namespace
|
||||||
|
-llvmlibc-implementation-in-namespace
|
||||||
|
-llvm-else-after-return
|
||||||
|
-llvm-qualified-auto
|
||||||
|
-misc-misplaced-const
|
||||||
|
-misc-non-private-member-variables-in-classes
|
||||||
|
-misc-no-recursion
|
||||||
|
-modernize-avoid-bind
|
||||||
|
-modernize-avoid-c-arrays
|
||||||
|
-modernize-pass-by-value
|
||||||
|
-modernize-use-auto
|
||||||
|
-modernize-use-default-member-init
|
||||||
|
-modernize-use-equals-default
|
||||||
|
-modernize-use-trailing-return-type
|
||||||
|
-modernize-use-transparent-functors
|
||||||
|
-performance-unnecessary-value-param
|
||||||
|
-readability-braces-around-statements
|
||||||
|
-readability-else-after-return
|
||||||
|
# we are not ready to use it, but very useful
|
||||||
|
-readability-function-cognitive-complexity
|
||||||
|
-readability-isolate-declaration
|
||||||
|
-readability-magic-numbers
|
||||||
|
-readability-named-parameter
|
||||||
|
-readability-uppercase-literal-suffix
|
||||||
|
-readability-convert-member-functions-to-static
|
||||||
|
-readability-qualified-auto
|
||||||
|
-readability-redundant-string-init
|
||||||
|
# too many narrowing conversions in our code
|
||||||
|
-bugprone-narrowing-conversions
|
||||||
|
-cppcoreguidelines-narrowing-conversions
|
||||||
|
-altera-struct-pack-align
|
||||||
|
-cppcoreguidelines-prefer-member-initializer
|
||||||
|
|
||||||
|
${MIOPEN_TIDY_CHECKS}
|
||||||
|
${MIOPEN_TIDY_ERRORS}
|
||||||
|
HEADER_FILTER
|
||||||
|
"\.hpp$"
|
||||||
|
EXTRA_ARGS
|
||||||
|
-DMIOPEN_USE_CLANG_TIDY
|
||||||
|
)
|
||||||
|
|
||||||
|
include(CppCheck)
|
||||||
|
enable_cppcheck(
|
||||||
|
CHECKS
|
||||||
|
warning
|
||||||
|
style
|
||||||
|
performance
|
||||||
|
portability
|
||||||
|
SUPPRESS
|
||||||
|
ConfigurationNotChecked
|
||||||
|
constStatement
|
||||||
|
duplicateCondition
|
||||||
|
noExplicitConstructor
|
||||||
|
passedByValue
|
||||||
|
preprocessorErrorDirective
|
||||||
|
shadowVariable
|
||||||
|
unusedFunction
|
||||||
|
unusedPrivateFunction
|
||||||
|
unusedStructMember
|
||||||
|
unmatchedSuppression
|
||||||
|
FORCE
|
||||||
|
SOURCES
|
||||||
|
host/host_tensor/src
|
||||||
|
host/driver_offline/src
|
||||||
|
composable_kernel/src/kernel_wrapper
|
||||||
|
INCLUDE
|
||||||
|
host/host_tensor/include
|
||||||
|
host/solver/include
|
||||||
|
host/driver_offline/include
|
||||||
|
composable_kernel/include/*
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||||
|
DEFINE
|
||||||
|
CPPCHECK=1
|
||||||
|
__linux__=1
|
||||||
|
)
|
||||||
|
|
||||||
add_subdirectory(host)
|
add_subdirectory(host)
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -78,7 +78,7 @@ InLeftPads size 2, {1, 1, }
|
|||||||
InRightPads size 2, {1, 1, }
|
InRightPads size 2, {1, 1, }
|
||||||
ConvStrides size 2, {2, 2, }
|
ConvStrides size 2, {2, 2, }
|
||||||
ConvDilations size 2, {1, 1, }
|
ConvDilations size 2, {1, 1, }
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
||||||
a_k0_m_k1_grid_desc{216, 256, 8}
|
a_k0_m_k1_grid_desc{216, 256, 8}
|
||||||
b_k0_n_k1_grid_desc{216, 165888, 8}
|
b_k0_n_k1_grid_desc{216, 165888, 8}
|
||||||
c_m_n_grid_desc{ 256, 165888}
|
c_m_n_grid_desc{ 256, 165888}
|
||||||
@@ -100,7 +100,7 @@ InLeftPads size 2, {1, 1, }
|
|||||||
InRightPads size 2, {1, 1, }
|
InRightPads size 2, {1, 1, }
|
||||||
ConvStrides size 2, {1, 1, }
|
ConvStrides size 2, {1, 1, }
|
||||||
ConvDilations size 2, {1, 1, }
|
ConvDilations size 2, {1, 1, }
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
||||||
a_k0_m_k1_grid_desc{288, 1024, 8}
|
a_k0_m_k1_grid_desc{288, 1024, 8}
|
||||||
b_k0_n_k1_grid_desc{288, 50176, 8}
|
b_k0_n_k1_grid_desc{288, 50176, 8}
|
||||||
c_m_n_grid_desc{ 1024, 50176}
|
c_m_n_grid_desc{ 1024, 50176}
|
||||||
@@ -122,7 +122,7 @@ InLeftPads size 2, {1, 1, }
|
|||||||
InRightPads size 2, {1, 1, }
|
InRightPads size 2, {1, 1, }
|
||||||
ConvStrides size 2, {2, 2, }
|
ConvStrides size 2, {2, 2, }
|
||||||
ConvDilations size 2, {1, 1, }
|
ConvDilations size 2, {1, 1, }
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
||||||
a_k0_m_k1_grid_desc{216, 165888, 8}
|
a_k0_m_k1_grid_desc{216, 165888, 8}
|
||||||
b_k0_n_k1_grid_desc{216, 256, 8}
|
b_k0_n_k1_grid_desc{216, 256, 8}
|
||||||
c_m_n_grid_desc{ 165888, 256}
|
c_m_n_grid_desc{ 165888, 256}
|
||||||
@@ -144,7 +144,7 @@ InLeftPads size 2, {1, 1, }
|
|||||||
InRightPads size 2, {1, 1, }
|
InRightPads size 2, {1, 1, }
|
||||||
ConvStrides size 2, {1, 1, }
|
ConvStrides size 2, {1, 1, }
|
||||||
ConvDilations size 2, {1, 1, }
|
ConvDilations size 2, {1, 1, }
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
||||||
a_k0_m_k1_grid_desc{288, 50176, 8}
|
a_k0_m_k1_grid_desc{288, 50176, 8}
|
||||||
b_k0_n_k1_grid_desc{288, 1024, 8}
|
b_k0_n_k1_grid_desc{288, 1024, 8}
|
||||||
c_m_n_grid_desc{ 50176, 1024}
|
c_m_n_grid_desc{ 50176, 1024}
|
||||||
@@ -166,7 +166,7 @@ InLeftPads size 2, {1, 1, }
|
|||||||
InRightPads size 2, {1, 1, }
|
InRightPads size 2, {1, 1, }
|
||||||
ConvStrides size 2, {1, 1, }
|
ConvStrides size 2, {1, 1, }
|
||||||
ConvDilations size 2, {1, 1, }
|
ConvDilations size 2, {1, 1, }
|
||||||
device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||||
a_k0_m_k1_grid_desc{288, 50176, 8}
|
a_k0_m_k1_grid_desc{288, 50176, 8}
|
||||||
b_k0_n_k1_grid_desc{288, 1024, 8}
|
b_k0_n_k1_grid_desc{288, 1024, 8}
|
||||||
c_m_n_grid_desc{ 50176, 1024}
|
c_m_n_grid_desc{ 50176, 1024}
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
|
|
||||||
function(add_kernels SRC_DIR KERNEL_FILES)
|
|
||||||
set(INIT_KERNELS_LIST)
|
|
||||||
set(KERNELS_DECLS)
|
|
||||||
foreach(KERNEL_FILE ${KERNEL_FILES})
|
|
||||||
if("${CMAKE_VERSION}" VERSION_LESS 3.0)
|
|
||||||
configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete)
|
|
||||||
else()
|
|
||||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE})
|
|
||||||
endif()
|
|
||||||
get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE)
|
|
||||||
string(TOUPPER "${BASE_NAME}" KEY_NAME)
|
|
||||||
string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME)
|
|
||||||
string(APPEND KERNELS_DECLS "extern const size_t APP_KERNEL_${VAR_NAME}_SIZE;\n")
|
|
||||||
string(APPEND KERNELS_DECLS "extern const unsigned char APP_KERNEL_${VAR_NAME}[];\n")
|
|
||||||
list(APPEND INIT_KERNELS_LIST " { \"${KEY_NAME}\", std::string(reinterpret_cast<const char*>(APP_KERNEL_${VAR_NAME}), APP_KERNEL_${VAR_NAME}_SIZE) }")
|
|
||||||
endforeach()
|
|
||||||
string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}")
|
|
||||||
configure_file(${SRC_DIR}/kernel.cpp.in ${PROJECT_BINARY_DIR}/kernel.cpp)
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
function(add_kernel_includes SRC_DIR KERNEL_FILES)
|
|
||||||
set(INIT_KERNELS_LIST)
|
|
||||||
foreach(KERNEL_FILE ${KERNEL_FILES})
|
|
||||||
if("${CMAKE_VERSION}" VERSION_LESS 3.0)
|
|
||||||
configure_file(${KERNEL_FILE} ${KERNEL_FILE}.delete)
|
|
||||||
else()
|
|
||||||
set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${KERNEL_FILE})
|
|
||||||
endif()
|
|
||||||
get_filename_component(BASE_NAME ${KERNEL_FILE} NAME_WE)
|
|
||||||
get_filename_component(FILE_NAME ${KERNEL_FILE} NAME)
|
|
||||||
string(TOUPPER "${BASE_NAME}" KEY_NAME)
|
|
||||||
string(MAKE_C_IDENTIFIER "${KEY_NAME}" VAR_NAME)
|
|
||||||
list(APPEND INIT_KERNELS_LIST " { \"${FILE_NAME}\", std::string(reinterpret_cast<const char*>(${VAR_NAME}), ${VAR_NAME}_SIZE) }")
|
|
||||||
endforeach()
|
|
||||||
string(REPLACE ";" ",\n" INIT_KERNELS "${INIT_KERNELS_LIST}")
|
|
||||||
configure_file(${SRC_DIR}/kernel_includes.cpp.in ${PROJECT_BINARY_DIR}/kernel_includes.cpp)
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +24,11 @@
|
|||||||
#
|
#
|
||||||
################################################################################
|
################################################################################
|
||||||
|
|
||||||
set(ADD_KERNELS_SOURCE include_inliner.cpp addkernels.cpp)
|
if(NOT TARGET analyze)
|
||||||
|
add_custom_target(analyze)
|
||||||
|
endif()
|
||||||
|
|
||||||
add_executable(addkernels EXCLUDE_FROM_ALL ${ADD_KERNELS_SOURCE})
|
function(mark_as_analyzer)
|
||||||
|
add_dependencies(analyze ${ARGN})
|
||||||
|
endfunction()
|
||||||
|
|
||||||
162
cmake/ClangTidy.cmake
Normal file
162
cmake/ClangTidy.cmake
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
################################################################################
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
#
|
||||||
|
################################################################################
|
||||||
|
include(CMakeParseArguments)
|
||||||
|
include(Analyzers)
|
||||||
|
|
||||||
|
get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH)
|
||||||
|
|
||||||
|
find_program(CLANG_TIDY_EXE
|
||||||
|
NAMES
|
||||||
|
clang-tidy
|
||||||
|
clang-tidy-5.0
|
||||||
|
clang-tidy-4.0
|
||||||
|
clang-tidy-3.9
|
||||||
|
clang-tidy-3.8
|
||||||
|
clang-tidy-3.7
|
||||||
|
clang-tidy-3.6
|
||||||
|
clang-tidy-3.5
|
||||||
|
HINTS
|
||||||
|
${CLANG_TIDY_EXE_HINT}
|
||||||
|
PATH_SUFFIXES
|
||||||
|
compiler/bin
|
||||||
|
PATHS
|
||||||
|
/opt/rocm/llvm/bin
|
||||||
|
/opt/rocm/hcc
|
||||||
|
/usr/local/opt/llvm/bin
|
||||||
|
)
|
||||||
|
|
||||||
|
function(find_clang_tidy_version VAR)
|
||||||
|
execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT)
|
||||||
|
separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}")
|
||||||
|
list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX)
|
||||||
|
if(VERSION_INDEX GREATER 0)
|
||||||
|
math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1")
|
||||||
|
list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION)
|
||||||
|
set(${VAR} ${VERSION} PARENT_SCOPE)
|
||||||
|
else()
|
||||||
|
set(${VAR} "0.0" PARENT_SCOPE)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
endfunction()
|
||||||
|
|
||||||
|
if( NOT CLANG_TIDY_EXE )
|
||||||
|
message( STATUS "Clang tidy not found" )
|
||||||
|
set(CLANG_TIDY_VERSION "0.0")
|
||||||
|
else()
|
||||||
|
find_clang_tidy_version(CLANG_TIDY_VERSION)
|
||||||
|
message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
|
set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits)
|
||||||
|
file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR})
|
||||||
|
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR})
|
||||||
|
|
||||||
|
macro(enable_clang_tidy)
|
||||||
|
set(options ANALYZE_TEMPORARY_DTORS ALL)
|
||||||
|
set(oneValueArgs HEADER_FILTER)
|
||||||
|
set(multiValueArgs CHECKS ERRORS EXTRA_ARGS)
|
||||||
|
|
||||||
|
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
|
string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}")
|
||||||
|
string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}")
|
||||||
|
set(CLANG_TIDY_EXTRA_ARGS)
|
||||||
|
foreach(ARG ${PARSE_EXTRA_ARGS})
|
||||||
|
list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
set(CLANG_TIDY_ALL)
|
||||||
|
if(PARSE_ALL)
|
||||||
|
set(CLANG_TIDY_ALL ALL)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}")
|
||||||
|
|
||||||
|
if (${PARSE_ANALYZE_TEMPORARY_DTORS})
|
||||||
|
set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
|
||||||
|
set(CLANG_TIDY_ERRORS_ARG "")
|
||||||
|
else()
|
||||||
|
set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
|
||||||
|
set(CLANG_TIDY_QUIET_ARG "")
|
||||||
|
else()
|
||||||
|
set(CLANG_TIDY_QUIET_ARG "-quiet")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(PARSE_HEADER_FILTER)
|
||||||
|
string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}")
|
||||||
|
else()
|
||||||
|
set(CLANG_TIDY_HEADER_FILTER ".*")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(CLANG_TIDY_COMMAND
|
||||||
|
${CLANG_TIDY_EXE}
|
||||||
|
${CLANG_TIDY_QUIET_ARG}
|
||||||
|
-p ${CMAKE_BINARY_DIR}
|
||||||
|
-checks='${CLANG_TIDY_CHECKS}'
|
||||||
|
${CLANG_TIDY_ERRORS_ARG}
|
||||||
|
${CLANG_TIDY_EXTRA_ARGS}
|
||||||
|
${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS}
|
||||||
|
-header-filter='${CLANG_TIDY_HEADER_FILTER}'
|
||||||
|
)
|
||||||
|
add_custom_target(tidy ${CLANG_TIDY_ALL})
|
||||||
|
mark_as_analyzer(tidy)
|
||||||
|
add_custom_target(tidy-base)
|
||||||
|
add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR})
|
||||||
|
add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR})
|
||||||
|
add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir)
|
||||||
|
add_dependencies(tidy-base tidy-make-fixit-dir)
|
||||||
|
endmacro()
|
||||||
|
|
||||||
|
function(clang_tidy_check TARGET)
|
||||||
|
get_target_property(SOURCES ${TARGET} SOURCES)
|
||||||
|
# TODO: Use generator expressions instead
|
||||||
|
# COMMAND ${CLANG_TIDY_COMMAND} $<TARGET_PROPERTY:${TARGET},SOURCES>
|
||||||
|
# COMMAND ${CLANG_TIDY_COMMAND} $<JOIN:$<TARGET_PROPERTY:${TARGET},SOURCES>, >
|
||||||
|
foreach(SOURCE ${SOURCES})
|
||||||
|
if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS"))
|
||||||
|
string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file)
|
||||||
|
set(tidy_target tidy-target-${TARGET}-${tidy_file})
|
||||||
|
add_custom_target(${tidy_target}
|
||||||
|
# for some targets clang-tidy not able to get information from .clang-tidy
|
||||||
|
DEPENDS ${SOURCE}
|
||||||
|
COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
|
||||||
|
)
|
||||||
|
add_dependencies(${tidy_target} ${TARGET})
|
||||||
|
add_dependencies(${tidy_target} tidy-base)
|
||||||
|
add_dependencies(tidy ${tidy_target})
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
endfunction()
|
||||||
|
|
||||||
130
cmake/CppCheck.cmake
Normal file
130
cmake/CppCheck.cmake
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
################################################################################
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
#
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
include(CMakeParseArguments)
|
||||||
|
include(ProcessorCount)
|
||||||
|
include(Analyzers)
|
||||||
|
|
||||||
|
find_program(CPPCHECK_EXE
|
||||||
|
NAMES
|
||||||
|
cppcheck
|
||||||
|
PATHS
|
||||||
|
/opt/rocm/bin
|
||||||
|
)
|
||||||
|
|
||||||
|
ProcessorCount(CPPCHECK_JOBS)
|
||||||
|
|
||||||
|
set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build)
|
||||||
|
file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR})
|
||||||
|
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR})
|
||||||
|
|
||||||
|
macro(enable_cppcheck)
|
||||||
|
set(options FORCE)
|
||||||
|
set(oneValueArgs)
|
||||||
|
set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES)
|
||||||
|
|
||||||
|
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
|
string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}")
|
||||||
|
string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*")
|
||||||
|
file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}")
|
||||||
|
set(CPPCHECK_DEFINES)
|
||||||
|
foreach(DEF ${PARSE_DEFINE})
|
||||||
|
set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
set(CPPCHECK_UNDEFINES)
|
||||||
|
foreach(DEF ${PARSE_UNDEFINE})
|
||||||
|
set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
set(CPPCHECK_INCLUDES)
|
||||||
|
foreach(INC ${PARSE_INCLUDE})
|
||||||
|
set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}")
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
# set(CPPCHECK_FORCE)
|
||||||
|
set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json")
|
||||||
|
if(PARSE_FORCE)
|
||||||
|
set(CPPCHECK_FORCE --force)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(SOURCES)
|
||||||
|
set(GLOBS)
|
||||||
|
foreach(SOURCE ${PARSE_SOURCES})
|
||||||
|
get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE)
|
||||||
|
if(EXISTS ${ABS_SOURCE})
|
||||||
|
if(IS_DIRECTORY ${ABS_SOURCE})
|
||||||
|
set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h")
|
||||||
|
else()
|
||||||
|
set(SOURCES "${SOURCES} ${ABS_SOURCE}")
|
||||||
|
endif()
|
||||||
|
else()
|
||||||
|
set(GLOBS "${GLOBS} ${ABS_SOURCE}")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake "
|
||||||
|
file(GLOB_RECURSE GSRCS ${GLOBS})
|
||||||
|
set(CPPCHECK_COMMAND
|
||||||
|
${CPPCHECK_EXE}
|
||||||
|
-q
|
||||||
|
# -v
|
||||||
|
# --report-progress
|
||||||
|
${CPPCHECK_FORCE}
|
||||||
|
--cppcheck-build-dir=${CPPCHECK_BUILD_DIR}
|
||||||
|
--platform=native
|
||||||
|
--template=gcc
|
||||||
|
--error-exitcode=1
|
||||||
|
-j ${CPPCHECK_JOBS}
|
||||||
|
${CPPCHECK_DEFINES}
|
||||||
|
${CPPCHECK_UNDEFINES}
|
||||||
|
${CPPCHECK_INCLUDES}
|
||||||
|
--enable=${CPPCHECK_CHECKS}
|
||||||
|
--inline-suppr
|
||||||
|
--suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions
|
||||||
|
${SOURCES} \${GSRCS}
|
||||||
|
)
|
||||||
|
string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\")
|
||||||
|
message(\"\${CPPCHECK_SHOW_COMMAND}\")
|
||||||
|
execute_process(
|
||||||
|
COMMAND \${CPPCHECK_COMMAND}
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
RESULT_VARIABLE RESULT
|
||||||
|
)
|
||||||
|
if(NOT RESULT EQUAL 0)
|
||||||
|
message(FATAL_ERROR \"Cppcheck failed\")
|
||||||
|
endif()
|
||||||
|
")
|
||||||
|
|
||||||
|
add_custom_target(cppcheck
|
||||||
|
COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
COMMENT "cppcheck: Running cppcheck..."
|
||||||
|
)
|
||||||
|
mark_as_analyzer(cppcheck)
|
||||||
|
endmacro()
|
||||||
|
|
||||||
|
|
||||||
355
cmake/DoxygenDoc.cmake
Normal file
355
cmake/DoxygenDoc.cmake
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
################################################################################
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
#
|
||||||
|
################################################################################
|
||||||
|
include(CMakeParseArguments)
|
||||||
|
include(MainDoc)
|
||||||
|
|
||||||
|
find_program(DOXYGEN_EXECUTABLE NAMES doxygen
|
||||||
|
PATH_SUFFIXES bin
|
||||||
|
DOC "Doxygen documentation generator"
|
||||||
|
)
|
||||||
|
mark_as_advanced(DOXYGEN_EXECUTABLE)
|
||||||
|
|
||||||
|
find_path(DOT_EXECUTABLE NAMES dot
|
||||||
|
PATH_SUFFIXES bin
|
||||||
|
DOC "Graphviz"
|
||||||
|
)
|
||||||
|
mark_as_advanced(DOT_EXECUTABLE)
|
||||||
|
|
||||||
|
set(DOXYGEN_ARGS
|
||||||
|
ABBREVIATE_BRIEF
|
||||||
|
ALIASES
|
||||||
|
ALLEXTERNALS
|
||||||
|
ALLOW_UNICODE_NAMES
|
||||||
|
ALPHABETICAL_INDEX
|
||||||
|
ALWAYS_DETAILED_SEC
|
||||||
|
AUTOLINK_SUPPORT
|
||||||
|
BINARY_TOC
|
||||||
|
BRIEF_MEMBER_DESC
|
||||||
|
BUILTIN_STL_SUPPORT
|
||||||
|
CALLER_GRAPH
|
||||||
|
CALL_GRAPH
|
||||||
|
CASE_SENSE_NAMES
|
||||||
|
CHM_FILE
|
||||||
|
CHM_INDEX_ENCODING
|
||||||
|
CITE_BIB_FILES
|
||||||
|
CLANG_ASSISTED_PARSING
|
||||||
|
CLANG_OPTIONS
|
||||||
|
CLASS_DIAGRAMS
|
||||||
|
CLASS_GRAPH
|
||||||
|
COLLABORATION_GRAPH
|
||||||
|
COLS_IN_ALPHA_INDEX
|
||||||
|
COMPACT_LATEX
|
||||||
|
COMPACT_RTF
|
||||||
|
CPP_CLI_SUPPORT
|
||||||
|
CREATE_SUBDIRS
|
||||||
|
DIAFILE_DIRS
|
||||||
|
DIA_PATH
|
||||||
|
DIRECTORY_GRAPH
|
||||||
|
DISABLE_INDEX
|
||||||
|
DISTRIBUTE_GROUP_DOC
|
||||||
|
DOCBOOK_OUTPUT
|
||||||
|
DOCBOOK_PROGRAMLISTING
|
||||||
|
DOCSET_BUNDLE_ID
|
||||||
|
DOCSET_FEEDNAME
|
||||||
|
DOCSET_PUBLISHER_ID
|
||||||
|
DOCSET_PUBLISHER_NAME
|
||||||
|
DOTFILE_DIRS
|
||||||
|
DOT_CLEANUP
|
||||||
|
DOT_FONTNAME
|
||||||
|
DOT_FONTPATH
|
||||||
|
DOT_FONTSIZE
|
||||||
|
DOT_GRAPH_MAX_NODES
|
||||||
|
DOT_IMAGE_FORMAT
|
||||||
|
DOT_MULTI_TARGETS
|
||||||
|
DOT_NUM_THREADS
|
||||||
|
# DOT_PATH
|
||||||
|
DOT_TRANSPARENT
|
||||||
|
DOXYFILE_ENCODING
|
||||||
|
ECLIPSE_DOC_ID
|
||||||
|
ENABLED_SECTIONS
|
||||||
|
ENABLE_PREPROCESSING
|
||||||
|
ENUM_VALUES_PER_LINE
|
||||||
|
EXAMPLE_PATH
|
||||||
|
EXAMPLE_PATTERNS
|
||||||
|
EXAMPLE_RECURSIVE
|
||||||
|
EXCLUDE
|
||||||
|
EXCLUDE_PATTERNS
|
||||||
|
EXCLUDE_SYMBOLS
|
||||||
|
EXCLUDE_SYMLINKS
|
||||||
|
EXPAND_AS_DEFINED
|
||||||
|
EXPAND_ONLY_PREDEF
|
||||||
|
EXTENSION_MAPPING
|
||||||
|
EXTERNAL_GROUPS
|
||||||
|
EXTERNAL_PAGES
|
||||||
|
EXTERNAL_SEARCH
|
||||||
|
EXTERNAL_SEARCH_ID
|
||||||
|
EXTRACT_ALL
|
||||||
|
EXTRACT_ANON_NSPACES
|
||||||
|
EXTRACT_LOCAL_CLASSES
|
||||||
|
EXTRACT_LOCAL_METHODS
|
||||||
|
EXTRACT_PACKAGE
|
||||||
|
EXTRACT_PRIVATE
|
||||||
|
EXTRACT_STATIC
|
||||||
|
EXTRA_PACKAGES
|
||||||
|
EXTRA_SEARCH_MAPPINGS
|
||||||
|
EXT_LINKS_IN_WINDOW
|
||||||
|
FILE_PATTERNS
|
||||||
|
FILE_VERSION_FILTER
|
||||||
|
FILTER_PATTERNS
|
||||||
|
FILTER_SOURCE_FILES
|
||||||
|
FILTER_SOURCE_PATTERNS
|
||||||
|
FORCE_LOCAL_INCLUDES
|
||||||
|
FORMULA_FONTSIZE
|
||||||
|
FORMULA_TRANSPARENT
|
||||||
|
FULL_PATH_NAMES
|
||||||
|
GENERATE_AUTOGEN_DEF
|
||||||
|
GENERATE_BUGLIST
|
||||||
|
GENERATE_CHI
|
||||||
|
GENERATE_DEPRECATEDLIST
|
||||||
|
GENERATE_DOCBOOK
|
||||||
|
GENERATE_DOCSET
|
||||||
|
GENERATE_ECLIPSEHELP
|
||||||
|
GENERATE_HTML
|
||||||
|
GENERATE_HTMLHELP
|
||||||
|
GENERATE_LATEX
|
||||||
|
GENERATE_LEGEND
|
||||||
|
GENERATE_MAN
|
||||||
|
GENERATE_PERLMOD
|
||||||
|
GENERATE_QHP
|
||||||
|
GENERATE_RTF
|
||||||
|
GENERATE_TAGFILE
|
||||||
|
GENERATE_TESTLIST
|
||||||
|
GENERATE_TODOLIST
|
||||||
|
GENERATE_TREEVIEW
|
||||||
|
GENERATE_XML
|
||||||
|
GRAPHICAL_HIERARCHY
|
||||||
|
GROUP_GRAPHS
|
||||||
|
GROUP_NESTED_COMPOUNDS
|
||||||
|
# HAVE_DOT
|
||||||
|
HHC_LOCATION
|
||||||
|
HIDE_COMPOUND_REFERENCE
|
||||||
|
HIDE_FRIEND_COMPOUNDS
|
||||||
|
HIDE_IN_BODY_DOCS
|
||||||
|
HIDE_SCOPE_NAMES
|
||||||
|
HIDE_UNDOC_CLASSES
|
||||||
|
HIDE_UNDOC_MEMBERS
|
||||||
|
HIDE_UNDOC_RELATIONS
|
||||||
|
HTML_COLORSTYLE_GAMMA
|
||||||
|
HTML_COLORSTYLE_HUE
|
||||||
|
HTML_COLORSTYLE_SAT
|
||||||
|
HTML_DYNAMIC_SECTIONS
|
||||||
|
HTML_EXTRA_FILES
|
||||||
|
HTML_EXTRA_STYLESHEET
|
||||||
|
HTML_FILE_EXTENSION
|
||||||
|
HTML_FOOTER
|
||||||
|
HTML_HEADER
|
||||||
|
HTML_INDEX_NUM_ENTRIES
|
||||||
|
HTML_OUTPUT
|
||||||
|
HTML_STYLESHEET
|
||||||
|
HTML_TIMESTAMP
|
||||||
|
IDL_PROPERTY_SUPPORT
|
||||||
|
IGNORE_PREFIX
|
||||||
|
IMAGE_PATH
|
||||||
|
INCLUDED_BY_GRAPH
|
||||||
|
INCLUDE_FILE_PATTERNS
|
||||||
|
INCLUDE_GRAPH
|
||||||
|
INCLUDE_PATH
|
||||||
|
INHERIT_DOCS
|
||||||
|
INLINE_GROUPED_CLASSES
|
||||||
|
INLINE_INFO
|
||||||
|
INLINE_INHERITED_MEMB
|
||||||
|
INLINE_SIMPLE_STRUCTS
|
||||||
|
INLINE_SOURCES
|
||||||
|
INPUT
|
||||||
|
INPUT_ENCODING
|
||||||
|
INPUT_FILTER
|
||||||
|
INTERACTIVE_SVG
|
||||||
|
INTERNAL_DOCS
|
||||||
|
JAVADOC_AUTOBRIEF
|
||||||
|
LATEX_BATCHMODE
|
||||||
|
LATEX_BIB_STYLE
|
||||||
|
LATEX_CMD_NAME
|
||||||
|
LATEX_EXTRA_FILES
|
||||||
|
LATEX_EXTRA_STYLESHEET
|
||||||
|
LATEX_FOOTER
|
||||||
|
LATEX_HEADER
|
||||||
|
LATEX_HIDE_INDICES
|
||||||
|
LATEX_OUTPUT
|
||||||
|
LATEX_SOURCE_CODE
|
||||||
|
LATEX_TIMESTAMP
|
||||||
|
LAYOUT_FILE
|
||||||
|
LOOKUP_CACHE_SIZE
|
||||||
|
MACRO_EXPANSION
|
||||||
|
MAKEINDEX_CMD_NAME
|
||||||
|
MAN_EXTENSION
|
||||||
|
MAN_LINKS
|
||||||
|
MAN_OUTPUT
|
||||||
|
MAN_SUBDIR
|
||||||
|
MARKDOWN_SUPPORT
|
||||||
|
MATHJAX_CODEFILE
|
||||||
|
MATHJAX_EXTENSIONS
|
||||||
|
MATHJAX_FORMAT
|
||||||
|
MATHJAX_RELPATH
|
||||||
|
MAX_DOT_GRAPH_DEPTH
|
||||||
|
MAX_INITIALIZER_LINES
|
||||||
|
MSCFILE_DIRS
|
||||||
|
MSCGEN_PATH
|
||||||
|
MULTILINE_CPP_IS_BRIEF
|
||||||
|
OPTIMIZE_FOR_FORTRAN
|
||||||
|
OPTIMIZE_OUTPUT_FOR_C
|
||||||
|
OPTIMIZE_OUTPUT_JAVA
|
||||||
|
OPTIMIZE_OUTPUT_VHDL
|
||||||
|
OUTPUT_DIRECTORY
|
||||||
|
OUTPUT_LANGUAGE
|
||||||
|
PAPER_TYPE
|
||||||
|
PDF_HYPERLINKS
|
||||||
|
PERLMOD_LATEX
|
||||||
|
PERLMOD_MAKEVAR_PREFIX
|
||||||
|
PERLMOD_PRETTY
|
||||||
|
PERL_PATH
|
||||||
|
PLANTUML_CFG_FILE
|
||||||
|
PLANTUML_INCLUDE_PATH
|
||||||
|
PLANTUML_JAR_PATH
|
||||||
|
PREDEFINED
|
||||||
|
PROJECT_BRIEF
|
||||||
|
PROJECT_LOGO
|
||||||
|
PROJECT_NAME
|
||||||
|
PROJECT_NUMBER
|
||||||
|
QCH_FILE
|
||||||
|
QHG_LOCATION
|
||||||
|
QHP_CUST_FILTER_ATTRS
|
||||||
|
QHP_CUST_FILTER_NAME
|
||||||
|
QHP_NAMESPACE
|
||||||
|
QHP_SECT_FILTER_ATTRS
|
||||||
|
QHP_VIRTUAL_FOLDER
|
||||||
|
QT_AUTOBRIEF
|
||||||
|
QUIET
|
||||||
|
RECURSIVE
|
||||||
|
REFERENCED_BY_RELATION
|
||||||
|
REFERENCES_LINK_SOURCE
|
||||||
|
REFERENCES_RELATION
|
||||||
|
REPEAT_BRIEF
|
||||||
|
RTF_EXTENSIONS_FILE
|
||||||
|
RTF_HYPERLINKS
|
||||||
|
RTF_OUTPUT
|
||||||
|
RTF_SOURCE_CODE
|
||||||
|
RTF_STYLESHEET_FILE
|
||||||
|
SEARCHDATA_FILE
|
||||||
|
SEARCHENGINE
|
||||||
|
SEARCHENGINE_URL
|
||||||
|
SEARCH_INCLUDES
|
||||||
|
SEPARATE_MEMBER_PAGES
|
||||||
|
SERVER_BASED_SEARCH
|
||||||
|
SHORT_NAMES
|
||||||
|
SHOW_FILES
|
||||||
|
SHOW_GROUPED_MEMB_INC
|
||||||
|
SHOW_INCLUDE_FILES
|
||||||
|
SHOW_NAMESPACES
|
||||||
|
SHOW_USED_FILES
|
||||||
|
SIP_SUPPORT
|
||||||
|
SKIP_FUNCTION_MACROS
|
||||||
|
SORT_BRIEF_DOCS
|
||||||
|
SORT_BY_SCOPE_NAME
|
||||||
|
SORT_GROUP_NAMES
|
||||||
|
SORT_MEMBERS_CTORS_1ST
|
||||||
|
SORT_MEMBER_DOCS
|
||||||
|
SOURCE_BROWSER
|
||||||
|
SOURCE_TOOLTIPS
|
||||||
|
STRICT_PROTO_MATCHING
|
||||||
|
STRIP_CODE_COMMENTS
|
||||||
|
STRIP_FROM_INC_PATH
|
||||||
|
STRIP_FROM_PATH
|
||||||
|
SUBGROUPING
|
||||||
|
TAB_SIZE
|
||||||
|
TAGFILES
|
||||||
|
TCL_SUBST
|
||||||
|
TEMPLATE_RELATIONS
|
||||||
|
TOC_EXPAND
|
||||||
|
TOC_INCLUDE_HEADINGS
|
||||||
|
TREEVIEW_WIDTH
|
||||||
|
TYPEDEF_HIDES_STRUCT
|
||||||
|
UML_LIMIT_NUM_FIELDS
|
||||||
|
UML_LOOK
|
||||||
|
USE_HTAGS
|
||||||
|
USE_MATHJAX
|
||||||
|
USE_MDFILE_AS_MAINPAGE
|
||||||
|
USE_PDFLATEX
|
||||||
|
VERBATIM_HEADERS
|
||||||
|
WARNINGS
|
||||||
|
WARN_AS_ERROR
|
||||||
|
WARN_FORMAT
|
||||||
|
WARN_IF_DOC_ERROR
|
||||||
|
WARN_IF_UNDOCUMENTED
|
||||||
|
WARN_LOGFILE
|
||||||
|
WARN_NO_PARAMDOC
|
||||||
|
XML_OUTPUT
|
||||||
|
XML_PROGRAMLISTING
|
||||||
|
)
|
||||||
|
|
||||||
|
set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file")
|
||||||
|
|
||||||
|
function(add_doxygen_doc)
|
||||||
|
set(options)
|
||||||
|
set(oneValueArgs)
|
||||||
|
set(multiValueArgs DEPENDS ${DOXYGEN_ARGS})
|
||||||
|
|
||||||
|
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||||
|
|
||||||
|
file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n")
|
||||||
|
|
||||||
|
foreach(ARG ${DOXYGEN_ARGS})
|
||||||
|
if(PARSE_${ARG})
|
||||||
|
string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}})
|
||||||
|
file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n")
|
||||||
|
endif()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
|
if(PARSE_OUTPUT_DIRECTORY)
|
||||||
|
if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY})
|
||||||
|
file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if(DOT_EXECUTABLE)
|
||||||
|
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n")
|
||||||
|
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n")
|
||||||
|
else()
|
||||||
|
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n")
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_custom_target(doxygen
|
||||||
|
${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE}
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
COMMENT "Building documentation with doxygen"
|
||||||
|
)
|
||||||
|
if(PARSE_OUTPUT_DIRECTORY)
|
||||||
|
clean_doc_output(${PARSE_OUTPUT_DIRECTORY})
|
||||||
|
endif()
|
||||||
|
mark_as_doc(doxygen)
|
||||||
|
if(PARSE_DEPENDS)
|
||||||
|
add_dependencies(doxygen ${PARSE_DEPENDS})
|
||||||
|
endif()
|
||||||
|
endfunction()
|
||||||
110
cmake/EnableCompilerWarnings.cmake
Normal file
110
cmake/EnableCompilerWarnings.cmake
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
################################################################################
|
||||||
|
#
|
||||||
|
# MIT License
|
||||||
|
#
|
||||||
|
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
# of this software and associated documentation files (the "Software"), to deal
|
||||||
|
# in the Software without restriction, including without limitation the rights
|
||||||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
# copies of the Software, and to permit persons to whom the Software is
|
||||||
|
# furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in all
|
||||||
|
# copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
# SOFTWARE.
|
||||||
|
#
|
||||||
|
################################################################################
|
||||||
|
# - Enable warning all for gcc/clang or use /W4 for visual studio
|
||||||
|
|
||||||
|
## Strict warning level
|
||||||
|
if (MSVC)
|
||||||
|
# Use the highest warning level for visual studio.
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
|
||||||
|
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w")
|
||||||
|
# set(CMAKE_CXX_WARNING_LEVEL 4)
|
||||||
|
# if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
|
||||||
|
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||||
|
# else ()
|
||||||
|
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
|
||||||
|
# endif ()
|
||||||
|
|
||||||
|
# set(CMAKE_C_WARNING_LEVEL 4)
|
||||||
|
# if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
|
||||||
|
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
|
||||||
|
# else ()
|
||||||
|
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
|
||||||
|
# endif ()
|
||||||
|
|
||||||
|
else()
|
||||||
|
foreach(COMPILER C CXX)
|
||||||
|
set(CMAKE_COMPILER_WARNINGS)
|
||||||
|
# use -Wall for gcc and clang
|
||||||
|
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||||
|
-Wall
|
||||||
|
-Wextra
|
||||||
|
-Wcomment
|
||||||
|
-Wendif-labels
|
||||||
|
-Wformat
|
||||||
|
-Winit-self
|
||||||
|
-Wreturn-type
|
||||||
|
-Wsequence-point
|
||||||
|
# Shadow is broken on gcc when using lambdas
|
||||||
|
# -Wshadow
|
||||||
|
-Wswitch
|
||||||
|
-Wtrigraphs
|
||||||
|
-Wundef
|
||||||
|
-Wuninitialized
|
||||||
|
-Wunreachable-code
|
||||||
|
-Wunused
|
||||||
|
|
||||||
|
-Wno-sign-compare
|
||||||
|
-Wno-extra-semi-stmt
|
||||||
|
)
|
||||||
|
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
|
||||||
|
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||||
|
-Weverything
|
||||||
|
-Wno-c++98-compat
|
||||||
|
-Wno-c++98-compat-pedantic
|
||||||
|
-Wno-conversion
|
||||||
|
-Wno-double-promotion
|
||||||
|
-Wno-exit-time-destructors
|
||||||
|
-Wno-extra-semi
|
||||||
|
-Wno-float-conversion
|
||||||
|
-Wno-gnu-anonymous-struct
|
||||||
|
-Wno-gnu-zero-variadic-macro-arguments
|
||||||
|
-Wno-missing-prototypes
|
||||||
|
-Wno-nested-anon-types
|
||||||
|
-Wno-padded
|
||||||
|
-Wno-return-std-move-in-c++11
|
||||||
|
-Wno-shorten-64-to-32
|
||||||
|
-Wno-sign-conversion
|
||||||
|
-Wno-unknown-warning-option
|
||||||
|
-Wno-unused-command-line-argument
|
||||||
|
-Wno-weak-vtables
|
||||||
|
-Wno-covered-switch-default
|
||||||
|
)
|
||||||
|
else()
|
||||||
|
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
|
||||||
|
# cmake 3.5.2 does not support >=.
|
||||||
|
if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1")
|
||||||
|
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||||
|
-Wno-ignored-attributes)
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||||
|
-Wno-missing-field-initializers
|
||||||
|
-Wno-deprecated-declarations
|
||||||
|
)
|
||||||
|
endif()
|
||||||
|
add_definitions(${CMAKE_COMPILER_WARNINGS})
|
||||||
|
endforeach()
|
||||||
|
endif ()
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
|
|
||||||
function(get_target_property2 VAR TARGET PROPERTY)
|
|
||||||
get_target_property(_pflags ${TARGET} ${PROPERTY})
|
|
||||||
if(_pflags)
|
|
||||||
set(${VAR} ${_pflags} PARENT_SCOPE)
|
|
||||||
else()
|
|
||||||
set(${VAR} "" PARENT_SCOPE)
|
|
||||||
endif()
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
|
|
||||||
macro(append_flags FLAGS TARGET PROPERTY PREFIX)
|
|
||||||
get_target_property2(_pflags ${TARGET} ${PROPERTY})
|
|
||||||
foreach(FLAG ${_pflags})
|
|
||||||
if(TARGET ${FLAG})
|
|
||||||
target_flags(_pflags2 ${FLAG})
|
|
||||||
string(APPEND ${FLAGS} " ${_pflags2}")
|
|
||||||
else()
|
|
||||||
string(APPEND ${FLAGS} " ${PREFIX}${FLAG}")
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
endmacro()
|
|
||||||
|
|
||||||
macro(append_link_flags FLAGS TARGET PROPERTY)
|
|
||||||
get_target_property2(_pflags ${TARGET} ${PROPERTY})
|
|
||||||
foreach(FLAG ${_pflags})
|
|
||||||
if(TARGET ${FLAG})
|
|
||||||
target_flags(_pflags2 ${FLAG})
|
|
||||||
string(APPEND ${FLAGS} " ${_pflags2}")
|
|
||||||
elseif(FLAG MATCHES "^-.*")
|
|
||||||
string(APPEND ${FLAGS} " ${FLAG}")
|
|
||||||
elseif(EXISTS ${FLAG})
|
|
||||||
string(APPEND ${FLAGS} " ${FLAG}")
|
|
||||||
else()
|
|
||||||
string(APPEND ${FLAGS} " -l${FLAG}")
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
endmacro()
|
|
||||||
|
|
||||||
function(target_flags FLAGS TARGET)
|
|
||||||
set(_flags)
|
|
||||||
append_flags(_flags ${TARGET} "INTERFACE_COMPILE_OPTIONS" "")
|
|
||||||
append_flags(_flags ${TARGET} "INTERFACE_COMPILE_DEFINITIONS" "-D")
|
|
||||||
append_flags(_flags ${TARGET} "INTERFACE_INCLUDE_DIRECTORIES" "-isystem ")
|
|
||||||
append_flags(_flags ${TARGET} "INTERFACE_LINK_DIRECTORIES" "-L ")
|
|
||||||
append_flags(_flags ${TARGET} "INTERFACE_LINK_OPTIONS" "")
|
|
||||||
append_link_flags(_flags ${TARGET} "INTERFACE_LINK_LIBRARIES" "")
|
|
||||||
# message("_flags: ${_flags}")
|
|
||||||
set(${FLAGS} ${_flags} PARENT_SCOPE)
|
|
||||||
endfunction()
|
|
||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
|
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -23,9 +23,9 @@ template <typename... Wei,
|
|||||||
index_t GemmK1Value>
|
index_t GemmK1Value>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -102,7 +102,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
const auto K0 = K / K1;
|
const auto K0 = K / K1;
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||||
wei_k_y_x_c_grid_desc,
|
wei_k_y_x_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(K),
|
make_tuple(make_pass_through_transform(K),
|
||||||
make_embed_transform(make_tuple(YDot, YTilda),
|
make_embed_transform(make_tuple(YDot, YTilda),
|
||||||
@@ -114,28 +114,28 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||||
make_slice_transform(YDot, I0, YDotSlice),
|
make_slice_transform(YDot, I0, YDotSlice),
|
||||||
make_slice_transform(XDot, I0, XDotSlice),
|
make_slice_transform(XDot, I0, XDotSlice),
|
||||||
make_freeze_transform(IYTilda),
|
make_freeze_transform(IYTilda),
|
||||||
make_freeze_transform(IXTilda),
|
make_freeze_transform(IXTilda),
|
||||||
make_pass_through_transform(C)),
|
make_pass_through_transform(C)),
|
||||||
make_tuple(Sequence<0>{},
|
make_tuple(Sequence<0>{},
|
||||||
Sequence<1>{},
|
Sequence<1>{},
|
||||||
Sequence<3>{},
|
Sequence<3>{},
|
||||||
Sequence<2>{},
|
Sequence<2>{},
|
||||||
Sequence<4>{},
|
Sequence<4>{},
|
||||||
Sequence<5>{}),
|
Sequence<5>{}),
|
||||||
make_tuple(Sequence<0, 1>{},
|
make_tuple(Sequence<0, 1>{},
|
||||||
Sequence<2>{},
|
Sequence<2>{},
|
||||||
Sequence<3>{},
|
Sequence<3>{},
|
||||||
Sequence<>{},
|
Sequence<>{},
|
||||||
Sequence<>{},
|
Sequence<>{},
|
||||||
Sequence<4>{}));
|
Sequence<4>{}));
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -143,7 +143,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||||
#else
|
#else
|
||||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -154,7 +154,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
// this add padding check
|
// this add padding check
|
||||||
const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_ho_wo_k_grid_desc,
|
out_n_ho_wo_k_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Ho, I0, I0),
|
make_pad_transform(Ho, I0, I0),
|
||||||
@@ -163,7 +163,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_hop_wop_k_grid_desc,
|
out_n_hop_wop_k_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(YDot, HTilda),
|
make_embed_transform(make_tuple(YDot, HTilda),
|
||||||
@@ -175,7 +175,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(
|
transform_tensor_descriptor(
|
||||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_slice_transform(YDot, I0, YDotSlice),
|
make_slice_transform(YDot, I0, YDotSlice),
|
||||||
@@ -197,7 +197,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
Sequence<5, 6>{}));
|
Sequence<5, 6>{}));
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||||
@@ -205,7 +205,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||||
#else
|
#else
|
||||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||||
@@ -215,7 +215,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hi_wi_c_grid_desc,
|
in_n_hi_wi_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||||
@@ -224,7 +224,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hip_wip_c_grid_desc,
|
in_n_hip_wip_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||||
@@ -235,7 +235,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_freeze_transform(IYTilda),
|
make_freeze_transform(IYTilda),
|
||||||
@@ -256,7 +256,7 @@ transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
|||||||
Sequence<2>{},
|
Sequence<2>{},
|
||||||
Sequence<3>{}));
|
Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(C),
|
make_tuple(make_pass_through_transform(C),
|
||||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
|
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
|
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -26,9 +26,9 @@ template <typename... Wei,
|
|||||||
index_t GemmK1Value>
|
index_t GemmK1Value>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -106,7 +106,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
|
|
||||||
// A: output tensor
|
// A: output tensor
|
||||||
// this add padding check
|
// this add padding check
|
||||||
const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_ho_wo_k_grid_desc,
|
out_n_ho_wo_k_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Ho, I0, I0),
|
make_pad_transform(Ho, I0, I0),
|
||||||
@@ -115,7 +115,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_hop_wop_k_grid_desc,
|
out_n_hop_wop_k_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(YDot, HTilda),
|
make_embed_transform(make_tuple(YDot, HTilda),
|
||||||
@@ -127,7 +127,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(
|
transform_tensor_descriptor(
|
||||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_slice_transform(YDot, I0, YDotSlice),
|
make_slice_transform(YDot, I0, YDotSlice),
|
||||||
@@ -149,7 +149,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
Sequence<5, 6>{}));
|
Sequence<5, 6>{}));
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||||
@@ -157,7 +157,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||||
#else
|
#else
|
||||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||||
@@ -167,7 +167,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// B: weight tensor
|
// B: weight tensor
|
||||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||||
wei_k_y_x_c_grid_desc,
|
wei_k_y_x_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(K),
|
make_tuple(make_pass_through_transform(K),
|
||||||
make_embed_transform(make_tuple(YDot, YTilda),
|
make_embed_transform(make_tuple(YDot, YTilda),
|
||||||
@@ -179,28 +179,28 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||||
make_slice_transform(YDot, I0, YDotSlice),
|
make_slice_transform(YDot, I0, YDotSlice),
|
||||||
make_slice_transform(XDot, I0, XDotSlice),
|
make_slice_transform(XDot, I0, XDotSlice),
|
||||||
make_freeze_transform(IYTilda),
|
make_freeze_transform(IYTilda),
|
||||||
make_freeze_transform(IXTilda),
|
make_freeze_transform(IXTilda),
|
||||||
make_pass_through_transform(C)),
|
make_pass_through_transform(C)),
|
||||||
make_tuple(Sequence<0>{},
|
make_tuple(Sequence<0>{},
|
||||||
Sequence<1>{},
|
Sequence<1>{},
|
||||||
Sequence<3>{},
|
Sequence<3>{},
|
||||||
Sequence<2>{},
|
Sequence<2>{},
|
||||||
Sequence<4>{},
|
Sequence<4>{},
|
||||||
Sequence<5>{}),
|
Sequence<5>{}),
|
||||||
make_tuple(Sequence<0, 1>{},
|
make_tuple(Sequence<0, 1>{},
|
||||||
Sequence<2>{},
|
Sequence<2>{},
|
||||||
Sequence<3>{},
|
Sequence<3>{},
|
||||||
Sequence<>{},
|
Sequence<>{},
|
||||||
Sequence<>{},
|
Sequence<>{},
|
||||||
Sequence<4>{}));
|
Sequence<4>{}));
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -208,7 +208,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||||
#else
|
#else
|
||||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -218,7 +218,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
// C: input tensor
|
// C: input tensor
|
||||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hi_wi_c_grid_desc,
|
in_n_hi_wi_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||||
@@ -227,7 +227,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hip_wip_c_grid_desc,
|
in_n_hip_wip_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||||
@@ -238,7 +238,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_freeze_transform(IYTilda),
|
make_freeze_transform(IYTilda),
|
||||||
@@ -259,7 +259,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
|||||||
Sequence<2>{},
|
Sequence<2>{},
|
||||||
Sequence<3>{}));
|
Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||||
make_pass_through_transform(C)),
|
make_pass_through_transform(C)),
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -18,9 +18,9 @@ template <typename... Wei,
|
|||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
const auto InRightPadW = in_right_pads[I1];
|
const auto InRightPadW = in_right_pads[I1];
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_global_desc,
|
in_n_c_hi_wi_global_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hip_wip_global_desc,
|
in_n_c_hip_wip_global_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||||
|
|
||||||
const auto in_gemmk_gemmn_global_desc =
|
const auto in_gemmk_gemmn_global_desc =
|
||||||
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
|
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
@@ -109,9 +109,9 @@ template <typename... Wei,
|
|||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -126,9 +126,6 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
|||||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||||
|
|
||||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
|
||||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
|
||||||
|
|
||||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||||
|
|
||||||
@@ -150,14 +147,14 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
|||||||
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
|
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_global_desc,
|
in_n_c_hi_wi_global_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -167,15 +164,15 @@ transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||||
|
|
||||||
const auto in_gemmk_gemmn_global_desc =
|
const auto in_gemmk_gemmn_global_desc =
|
||||||
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
|
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
@@ -192,9 +189,9 @@ template <typename... Wei,
|
|||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
|
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -209,9 +206,6 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||||
|
|
||||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
|
||||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
|
||||||
|
|
||||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||||
|
|
||||||
@@ -235,22 +229,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
InRightPadW == 0);
|
InRightPadW == 0);
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_global_desc,
|
in_n_c_hi_wi_global_desc,
|
||||||
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
|
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
|
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -18,9 +18,9 @@ template <typename... Wei,
|
|||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
|
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -57,14 +57,14 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
const auto InRightPadW = in_right_pads[I1];
|
const auto InRightPadW = in_right_pads[I1];
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hi_wi_c_grid_desc,
|
in_n_hi_wi_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||||
@@ -73,7 +73,7 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hip_wip_c_grid_desc,
|
in_n_hip_wip_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||||
@@ -83,15 +83,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto in_gemmk_gemmn_grid_desc =
|
const auto in_gemmk_gemmn_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
|
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
@@ -108,9 +108,9 @@ template <typename... Wei,
|
|||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
|
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -125,9 +125,6 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||||
|
|
||||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
|
||||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||||
|
|
||||||
@@ -151,22 +148,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_
|
|||||||
InRightPadW == 0);
|
InRightPadW == 0);
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)),
|
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
|
||||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
|
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
|
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -20,9 +20,9 @@ template <typename... Wei,
|
|||||||
index_t GemmK1Value>
|
index_t GemmK1Value>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
|||||||
const auto GemmK0 = GemmK / GemmK1;
|
const auto GemmK0 = GemmK / GemmK1;
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
|
||||||
wei_gemmk_gemmm_grid_desc,
|
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||||
make_pass_through_transform(GemmM)),
|
make_pass_through_transform(GemmM)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_grid_desc,
|
in_n_c_hi_wi_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hip_wip_grid_desc,
|
in_n_c_hip_wip_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||||
|
|
||||||
const auto in_gemmk_gemmn_grid_desc =
|
const auto in_gemmk_gemmn_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||||
in_gemmk_gemmn_grid_desc,
|
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||||
make_pass_through_transform(GemmN)),
|
make_pass_through_transform(GemmN)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)),
|
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -20,9 +20,9 @@ template <typename... Wei,
|
|||||||
index_t GemmK1Value>
|
index_t GemmK1Value>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -67,21 +67,21 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
|||||||
const auto GemmK0 = GemmK / GemmK1;
|
const auto GemmK0 = GemmK / GemmK1;
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
|
||||||
wei_gemmk_gemmm_grid_desc,
|
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||||
make_pass_through_transform(GemmM)),
|
make_pass_through_transform(GemmM)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hi_wi_c_grid_desc,
|
in_n_hi_wi_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||||
@@ -90,7 +90,7 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hip_wip_c_grid_desc,
|
in_n_hip_wip_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||||
@@ -100,22 +100,22 @@ transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto in_gemmk_gemmn_grid_desc =
|
const auto in_gemmk_gemmn_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||||
in_gemmk_gemmn_grid_desc,
|
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||||
make_pass_through_transform(GemmN)),
|
make_pass_through_transform(GemmN)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
|
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -23,9 +23,9 @@ template <typename... In,
|
|||||||
index_t GemmK1Value>
|
index_t GemmK1Value>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||||
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -70,7 +70,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
|||||||
const auto GemmK0 = GemmK / GemmK1;
|
const auto GemmK0 = GemmK / GemmK1;
|
||||||
|
|
||||||
// A: input tensor
|
// A: input tensor
|
||||||
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hi_wi_c_grid_desc,
|
in_n_hi_wi_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||||
@@ -79,7 +79,7 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_hip_wip_c_grid_desc,
|
in_n_hip_wip_c_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||||
@@ -89,36 +89,36 @@ transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||||
|
|
||||||
const auto in_gemmk_gemmm_grid_desc =
|
const auto in_gemmk_gemmm_grid_desc =
|
||||||
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|
||||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||||
in_gemmk_gemmm_grid_desc,
|
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||||
make_pass_through_transform(GemmM)),
|
make_pass_through_transform(GemmM)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// B: weight tensor
|
// B: weight tensor
|
||||||
const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
|
||||||
wei_gemmk_gemmn_grid_desc,
|
transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||||
make_pass_through_transform(GemmN)),
|
make_pass_through_transform(GemmN)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||||
|
|
||||||
// C: output tensor
|
// C: output tensor
|
||||||
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
|
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -24,9 +24,9 @@ template <typename... Wei,
|
|||||||
typename C0Type>
|
typename C0Type>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||||
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -68,15 +68,15 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
|||||||
const auto C1 = C / C0;
|
const auto C1 = C / C0;
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_gk0_gm0_gm1_gk1_grid_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||||
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
|
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
|
||||||
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
|
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
|
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_grid_desc,
|
in_n_c_hi_wi_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -85,7 +85,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hip_wip_grid_desc,
|
in_n_c_hip_wip_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||||
make_unmerge_transform(make_tuple(C0, C1)),
|
make_unmerge_transform(make_tuple(C0, C1)),
|
||||||
@@ -94,7 +94,7 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||||
|
|
||||||
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
|
||||||
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
|
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
|
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
|
||||||
make_pass_through_transform(N0),
|
make_pass_through_transform(N0),
|
||||||
@@ -105,17 +105,17 @@ transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
|||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_n_k_howo_grid_desc =
|
const auto out_n_k_howo_grid_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
|
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));
|
||||||
|
|
||||||
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_n0_n1_1_k_howo_grid_desc =
|
||||||
out_n_k_howo_grid_desc,
|
transform_tensor_descriptor(out_n_k_howo_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||||
make_unmerge_transform(make_tuple(I1, K)),
|
make_unmerge_transform(make_tuple(I1, K)),
|
||||||
make_pass_through_transform(Ho * Wo)),
|
make_pass_through_transform(Ho * Wo)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||||
|
|
||||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
|
||||||
out_n0_n1_1_k_howo_grid_desc,
|
out_n0_n1_1_k_howo_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(I1),
|
make_tuple(make_pass_through_transform(I1),
|
||||||
make_pass_through_transform(K),
|
make_pass_through_transform(K),
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP
|
#ifndef CK_MULTI_INDEX_TRANSFORM_HPP
|
||||||
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP
|
#define CK_MULTI_INDEX_TRANSFORM_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "multi_index.hpp"
|
#include "multi_index.hpp"
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <typename LowLength>
|
template <typename LowLength>
|
||||||
struct DynamicPassThrough
|
struct PassThrough
|
||||||
{
|
{
|
||||||
using LowerIndex = MultiIndex<1>;
|
using LowerIndex = MultiIndex<1>;
|
||||||
using UpperIndex = MultiIndex<1>;
|
using UpperIndex = MultiIndex<1>;
|
||||||
@@ -16,9 +16,9 @@ struct DynamicPassThrough
|
|||||||
|
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicPassThrough() = default;
|
__host__ __device__ constexpr PassThrough() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length)
|
__host__ __device__ constexpr PassThrough(const LowLength& low_length)
|
||||||
: up_lengths_{make_tuple(low_length)}
|
: up_lengths_{make_tuple(low_length)}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -82,33 +82,36 @@ struct DynamicPassThrough
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicPassThrough, ");
|
printf("PassThrough, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("}");
|
printf("}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
|
template <typename LowLength,
|
||||||
struct DynamicPad
|
typename LeftPadLength,
|
||||||
|
typename RightPadLength,
|
||||||
|
bool SkipIsValidCheck = false>
|
||||||
|
struct Pad
|
||||||
{
|
{
|
||||||
using LowerIndex = MultiIndex<1>;
|
using LowerIndex = MultiIndex<1>;
|
||||||
using UpperIndex = MultiIndex<1>;
|
using UpperIndex = MultiIndex<1>;
|
||||||
|
|
||||||
using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{}));
|
using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
|
||||||
|
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
LeftPad left_pad_;
|
LeftPadLength left_pad_length_;
|
||||||
RightPad right_pad_;
|
RightPadLength right_pad_length_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicPad() = default;
|
__host__ __device__ constexpr Pad() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicPad(const LowLength& low_length,
|
__host__ __device__ constexpr Pad(const LowLength& low_length,
|
||||||
const LeftPad& left_pad,
|
const LeftPadLength& left_pad_length,
|
||||||
const RightPad& right_pad)
|
const RightPadLength& right_pad_length)
|
||||||
: up_lengths_{make_tuple(low_length + left_pad + right_pad)},
|
: up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
|
||||||
left_pad_{left_pad},
|
left_pad_length_{left_pad_length},
|
||||||
right_pad_{right_pad}
|
right_pad_length_{right_pad_length}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +128,7 @@ struct DynamicPad
|
|||||||
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
|
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
|
||||||
"wrong! inconsistent # of dimension");
|
"wrong! inconsistent # of dimension");
|
||||||
|
|
||||||
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_;
|
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowIdxDiff,
|
template <typename LowIdxDiff,
|
||||||
@@ -161,45 +164,46 @@ struct DynamicPad
|
|||||||
__host__ __device__ constexpr bool
|
__host__ __device__ constexpr bool
|
||||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
|
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
|
||||||
{
|
{
|
||||||
return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) &&
|
return SkipIsValidCheck ||
|
||||||
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_));
|
((idx_up[Number<0>{}] >= left_pad_length_) &&
|
||||||
|
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||||
{
|
{
|
||||||
return is_known_at_compile_time<UpLengths>::value &&
|
return is_known_at_compile_time<UpLengths>::value &&
|
||||||
is_known_at_compile_time<LeftPad>::value &&
|
is_known_at_compile_time<LeftPadLength>::value &&
|
||||||
is_known_at_compile_time<RightPad>::value;
|
is_known_at_compile_time<RightPadLength>::value;
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicPad, ");
|
printf("Pad, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("left_pad_ %d", index_t{left_pad_});
|
printf("left_pad_length %d", index_t{left_pad_length_});
|
||||||
printf("right_pad_ %d", index_t{right_pad_});
|
printf("right_pad_length %d", index_t{right_pad_length_});
|
||||||
printf("}");
|
printf("}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false>
|
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||||
struct DynamicLeftPad
|
struct LeftPad
|
||||||
{
|
{
|
||||||
using LowerIndex = MultiIndex<1>;
|
using LowerIndex = MultiIndex<1>;
|
||||||
using UpperIndex = MultiIndex<1>;
|
using UpperIndex = MultiIndex<1>;
|
||||||
|
|
||||||
using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{}));
|
using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
|
||||||
|
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
LeftPad left_pad_;
|
LeftPadLength left_pad_length_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicLeftPad() = default;
|
__host__ __device__ constexpr LeftPad() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length,
|
__host__ __device__ constexpr LeftPad(const LowLength& low_length,
|
||||||
const LeftPad& left_pad)
|
const LeftPadLength& left_pad_length)
|
||||||
: up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad}
|
: up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,7 +220,7 @@ struct DynamicLeftPad
|
|||||||
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
|
static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
|
||||||
"wrong! inconsistent # of dimension");
|
"wrong! inconsistent # of dimension");
|
||||||
|
|
||||||
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_;
|
idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowIdxDiff,
|
template <typename LowIdxDiff,
|
||||||
@@ -252,45 +256,45 @@ struct DynamicLeftPad
|
|||||||
__host__ __device__ constexpr bool
|
__host__ __device__ constexpr bool
|
||||||
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
|
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
|
||||||
{
|
{
|
||||||
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
|
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||||
{
|
{
|
||||||
return is_known_at_compile_time<UpLengths>::value &&
|
return is_known_at_compile_time<UpLengths>::value &&
|
||||||
is_known_at_compile_time<LeftPad>::value;
|
is_known_at_compile_time<LeftPadLength>::value;
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicLeftPad, ");
|
printf("LeftPad, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("left_pad_ %d", index_t{left_pad_});
|
printf("left_pad_length_ %d", index_t{left_pad_length_});
|
||||||
printf("}");
|
printf("}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename LowLength, typename RightPad, bool SkipIsValidCheck = false>
|
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
|
||||||
struct DynamicRightPad
|
struct RightPad
|
||||||
{
|
{
|
||||||
using LowerIndex = MultiIndex<1>;
|
using LowerIndex = MultiIndex<1>;
|
||||||
using UpperIndex = MultiIndex<1>;
|
using UpperIndex = MultiIndex<1>;
|
||||||
|
|
||||||
using UpLengths = decltype(make_tuple(LowLength{} + RightPad{}));
|
using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
|
||||||
|
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
LowLength low_length_;
|
LowLength low_length_;
|
||||||
RightPad right_pad_;
|
RightPadLength right_pad_length_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicRightPad() = default;
|
__host__ __device__ constexpr RightPad() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicRightPad(const LowLength& low_length,
|
__host__ __device__ constexpr RightPad(const LowLength& low_length,
|
||||||
const RightPad& right_pad)
|
const RightPadLength& right_pad_length)
|
||||||
: up_lengths_{make_tuple(low_length + right_pad)},
|
: up_lengths_{make_tuple(low_length + right_pad_length)},
|
||||||
low_length_{low_length},
|
low_length_{low_length},
|
||||||
right_pad_{right_pad}
|
right_pad_length_{right_pad_length}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -350,17 +354,17 @@ struct DynamicRightPad
|
|||||||
{
|
{
|
||||||
return is_known_at_compile_time<UpLengths>::value &&
|
return is_known_at_compile_time<UpLengths>::value &&
|
||||||
is_known_at_compile_time<LowLength>::value &&
|
is_known_at_compile_time<LowLength>::value &&
|
||||||
is_known_at_compile_time<RightPad>::value;
|
is_known_at_compile_time<RightPadLength>::value;
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicRightPad, ");
|
printf("RightPad, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("low_length_ %d", index_t{low_length_});
|
printf("low_length_ %d", index_t{low_length_});
|
||||||
printf("left_pad_ %d", index_t{right_pad_});
|
printf("left_pad_length_ %d", index_t{right_pad_length_});
|
||||||
printf("}");
|
printf("}");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -373,8 +377,8 @@ struct DynamicRightPad
|
|||||||
// at compile-time
|
// at compile-time
|
||||||
template <typename UpLengths,
|
template <typename UpLengths,
|
||||||
typename Coefficients,
|
typename Coefficients,
|
||||||
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||||
struct DynamicEmbed
|
struct Embed
|
||||||
{
|
{
|
||||||
static constexpr index_t NDimUp = UpLengths::Size();
|
static constexpr index_t NDimUp = UpLengths::Size();
|
||||||
|
|
||||||
@@ -384,10 +388,10 @@ struct DynamicEmbed
|
|||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
Coefficients coefficients_;
|
Coefficients coefficients_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicEmbed() = default;
|
__host__ __device__ constexpr Embed() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicEmbed(const UpLengths& up_lengths,
|
__host__ __device__ constexpr Embed(const UpLengths& up_lengths,
|
||||||
const Coefficients& coefficients)
|
const Coefficients& coefficients)
|
||||||
: up_lengths_{up_lengths}, coefficients_{coefficients}
|
: up_lengths_{up_lengths}, coefficients_{coefficients}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -458,7 +462,7 @@ struct DynamicEmbed
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicEmbed, ");
|
printf("Embed, ");
|
||||||
printf("up_lengths_ ");
|
printf("up_lengths_ ");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("coefficients_ ");
|
printf("coefficients_ ");
|
||||||
@@ -470,7 +474,7 @@ struct DynamicEmbed
|
|||||||
// Implementation of "Merge" transformation primitive that uses regular to do lowering of
|
// Implementation of "Merge" transformation primitive that uses regular to do lowering of
|
||||||
// multi-index and use carry-and-borrow check to do lowering of multi-index delta
|
// multi-index and use carry-and-borrow check to do lowering of multi-index delta
|
||||||
template <typename LowLengths>
|
template <typename LowLengths>
|
||||||
struct DynamicMerge_v1_carry_check
|
struct Merge_v1_carry_check
|
||||||
{
|
{
|
||||||
static constexpr index_t NDimLow = LowLengths::Size();
|
static constexpr index_t NDimLow = LowLengths::Size();
|
||||||
|
|
||||||
@@ -487,9 +491,9 @@ struct DynamicMerge_v1_carry_check
|
|||||||
LowLengthsScan low_lengths_scan_;
|
LowLengthsScan low_lengths_scan_;
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicMerge_v1_carry_check() = default;
|
__host__ __device__ constexpr Merge_v1_carry_check() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicMerge_v1_carry_check(const LowLengths& low_lengths)
|
__host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
|
||||||
: low_lengths_{low_lengths},
|
: low_lengths_{low_lengths},
|
||||||
low_lengths_scan_{
|
low_lengths_scan_{
|
||||||
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
|
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
|
||||||
@@ -555,7 +559,7 @@ struct DynamicMerge_v1_carry_check
|
|||||||
LowerIndex idx_low_length_minus_idx_diff_low_const;
|
LowerIndex idx_low_length_minus_idx_diff_low_const;
|
||||||
LowerIndex idx_low_length_plus_idx_diff_low_const;
|
LowerIndex idx_low_length_plus_idx_diff_low_const;
|
||||||
|
|
||||||
#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
||||||
index_t tmp = idx_diff_up[Number<0>{}];
|
index_t tmp = idx_diff_up[Number<0>{}];
|
||||||
|
|
||||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||||
@@ -698,7 +702,7 @@ struct DynamicMerge_v1_carry_check
|
|||||||
LowerIndex idx_low_length_minus_idx_diff_low_const;
|
LowerIndex idx_low_length_minus_idx_diff_low_const;
|
||||||
LowerIndex idx_low_length_plus_idx_diff_low_const;
|
LowerIndex idx_low_length_plus_idx_diff_low_const;
|
||||||
|
|
||||||
#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
||||||
index_t tmp = idx_diff_up[Number<0>{}];
|
index_t tmp = idx_diff_up[Number<0>{}];
|
||||||
|
|
||||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||||
@@ -838,7 +842,7 @@ struct DynamicMerge_v1_carry_check
|
|||||||
// very expensive.
|
// very expensive.
|
||||||
LowerIndex idx_diff_low_const;
|
LowerIndex idx_diff_low_const;
|
||||||
|
|
||||||
#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
||||||
index_t tmp = idx_diff_up[Number<0>{}];
|
index_t tmp = idx_diff_up[Number<0>{}];
|
||||||
|
|
||||||
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
|
||||||
@@ -981,7 +985,7 @@ struct DynamicMerge_v1_carry_check
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicMerge_v1_carry_check, ");
|
printf("Merge_v1_carry_check, ");
|
||||||
printf("low_lengths_ ");
|
printf("low_lengths_ ");
|
||||||
print_multi_index(low_lengths_);
|
print_multi_index(low_lengths_);
|
||||||
printf("low_lengths_scan_ ");
|
printf("low_lengths_scan_ ");
|
||||||
@@ -1025,7 +1029,7 @@ struct lambda_merge_generate_MagicDivision_calculate_magic_shift
|
|||||||
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
|
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
|
||||||
// non-negative.
|
// non-negative.
|
||||||
template <typename LowLengths>
|
template <typename LowLengths>
|
||||||
struct DynamicMerge_v2_magic_division
|
struct Merge_v2_magic_division
|
||||||
{
|
{
|
||||||
static constexpr index_t NDimLow = LowLengths::Size();
|
static constexpr index_t NDimLow = LowLengths::Size();
|
||||||
|
|
||||||
@@ -1048,9 +1052,9 @@ struct DynamicMerge_v2_magic_division
|
|||||||
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
|
LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicMerge_v2_magic_division() = default;
|
__host__ __device__ constexpr Merge_v2_magic_division() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicMerge_v2_magic_division(const LowLengths& low_lengths)
|
__host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
|
||||||
: low_lengths_{low_lengths},
|
: low_lengths_{low_lengths},
|
||||||
low_lengths_magic_divisor_multiplier_{generate_tuple(
|
low_lengths_magic_divisor_multiplier_{generate_tuple(
|
||||||
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
|
[&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
|
||||||
@@ -1151,7 +1155,7 @@ struct DynamicMerge_v2_magic_division
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicMerge_v2_magic_division, ");
|
printf("Merge_v2_magic_division, ");
|
||||||
printf("low_lengths_ ");
|
printf("low_lengths_ ");
|
||||||
print_multi_index(low_lengths_);
|
print_multi_index(low_lengths_);
|
||||||
printf("low_lengths_magic_divisor_multiplier_ ");
|
printf("low_lengths_magic_divisor_multiplier_ ");
|
||||||
@@ -1177,7 +1181,7 @@ struct DynamicMerge_v2_magic_division
|
|||||||
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
|
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
|
||||||
// non-negative.
|
// non-negative.
|
||||||
template <typename LowLengths>
|
template <typename LowLengths>
|
||||||
struct DynamicMerge_v2r2_magic_division
|
struct Merge_v2r2_magic_division
|
||||||
{
|
{
|
||||||
static constexpr index_t NDimLow = LowLengths::Size();
|
static constexpr index_t NDimLow = LowLengths::Size();
|
||||||
|
|
||||||
@@ -1204,9 +1208,9 @@ struct DynamicMerge_v2r2_magic_division
|
|||||||
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
|
LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default;
|
__host__ __device__ constexpr Merge_v2r2_magic_division() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths)
|
__host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
|
||||||
: low_lengths_{low_lengths},
|
: low_lengths_{low_lengths},
|
||||||
low_lengths_scan_{
|
low_lengths_scan_{
|
||||||
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
|
container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})},
|
||||||
@@ -1308,7 +1312,7 @@ struct DynamicMerge_v2r2_magic_division
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicMerge_v2r2_magic_division, ");
|
printf("Merge_v2r2_magic_division, ");
|
||||||
printf("low_lengths_ ");
|
printf("low_lengths_ ");
|
||||||
print_multi_index(low_lengths_);
|
print_multi_index(low_lengths_);
|
||||||
printf("low_lengths_scan ");
|
printf("low_lengths_scan ");
|
||||||
@@ -1324,7 +1328,7 @@ struct DynamicMerge_v2r2_magic_division
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
template <typename UpLengths, bool Use24BitIntegerCalculation>
|
||||||
struct DynamicUnMerge
|
struct UnMerge
|
||||||
{
|
{
|
||||||
static constexpr index_t NDimUp = UpLengths::Size();
|
static constexpr index_t NDimUp = UpLengths::Size();
|
||||||
|
|
||||||
@@ -1337,9 +1341,9 @@ struct DynamicUnMerge
|
|||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
UpLengthsScan up_lengths_scan_;
|
UpLengthsScan up_lengths_scan_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicUnMerge() = default;
|
__host__ __device__ constexpr UnMerge() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths)
|
__host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
|
||||||
: up_lengths_{up_lengths},
|
: up_lengths_{up_lengths},
|
||||||
up_lengths_scan_{
|
up_lengths_scan_{
|
||||||
container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})}
|
container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})}
|
||||||
@@ -1414,7 +1418,7 @@ struct DynamicUnMerge
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicUnMerge, ");
|
printf("UnMerge, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("up_lengths_scan_");
|
printf("up_lengths_scan_");
|
||||||
@@ -1424,13 +1428,13 @@ struct DynamicUnMerge
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename LowerIndex>
|
template <typename LowerIndex>
|
||||||
struct DynamicFreeze
|
struct Freeze
|
||||||
{
|
{
|
||||||
LowerIndex low_idx_;
|
LowerIndex low_idx_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicFreeze() = default;
|
__host__ __device__ constexpr Freeze() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
|
__host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
|
||||||
|
|
||||||
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
|
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }
|
||||||
|
|
||||||
@@ -1483,22 +1487,22 @@ struct DynamicFreeze
|
|||||||
|
|
||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("DynamicFreeze");
|
printf("Freeze");
|
||||||
printf("low_idx_ %d", index_t{low_idx_});
|
printf("low_idx_ %d", index_t{low_idx_});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Insert a dangling upper dimension without lower dimension
|
// Insert a dangling upper dimension without lower dimension
|
||||||
template <typename UpperLength>
|
template <typename UpperLength>
|
||||||
struct DynamicInsert
|
struct Insert
|
||||||
{
|
{
|
||||||
using UpLengths = decltype(make_tuple(UpperLength{}));
|
using UpLengths = decltype(make_tuple(UpperLength{}));
|
||||||
|
|
||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicInsert() = default;
|
__host__ __device__ constexpr Insert() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicInsert(const UpperLength& up_length)
|
__host__ __device__ constexpr Insert(const UpperLength& up_length)
|
||||||
: up_lengths_{make_tuple(up_length)}
|
: up_lengths_{make_tuple(up_length)}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -1550,13 +1554,13 @@ struct DynamicInsert
|
|||||||
|
|
||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("DynamicInsert");
|
printf("Insert");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename VectorSize, typename UpLength>
|
template <typename VectorSize, typename UpLength>
|
||||||
struct DynamicVectorize
|
struct Vectorize
|
||||||
{
|
{
|
||||||
using LowerIndex = MultiIndex<1>;
|
using LowerIndex = MultiIndex<1>;
|
||||||
using UpperIndex = MultiIndex<1>;
|
using UpperIndex = MultiIndex<1>;
|
||||||
@@ -1566,10 +1570,10 @@ struct DynamicVectorize
|
|||||||
UpLengths up_lengths_;
|
UpLengths up_lengths_;
|
||||||
VectorSize vector_size_;
|
VectorSize vector_size_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicVectorize() = default;
|
__host__ __device__ constexpr Vectorize() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicVectorize(const VectorSize& vector_size,
|
__host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
|
||||||
const UpLength& up_length)
|
const UpLength& up_length)
|
||||||
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
|
: vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -1633,7 +1637,7 @@ struct DynamicVectorize
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicVectorize, ");
|
printf("Vectorize, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("}");
|
printf("}");
|
||||||
@@ -1641,7 +1645,7 @@ struct DynamicVectorize
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||||
struct DynamicSlice
|
struct Slice
|
||||||
{
|
{
|
||||||
using LowerIndex = MultiIndex<1>;
|
using LowerIndex = MultiIndex<1>;
|
||||||
using UpperIndex = MultiIndex<1>;
|
using UpperIndex = MultiIndex<1>;
|
||||||
@@ -1652,11 +1656,11 @@ struct DynamicSlice
|
|||||||
SliceBegin slice_begin_;
|
SliceBegin slice_begin_;
|
||||||
SliceEnd slice_end_;
|
SliceEnd slice_end_;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicSlice() = default;
|
__host__ __device__ constexpr Slice() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicSlice(const LowLength&,
|
__host__ __device__ constexpr Slice(const LowLength&,
|
||||||
const SliceBegin& slice_begin,
|
const SliceBegin& slice_begin,
|
||||||
const SliceEnd& slice_end)
|
const SliceEnd& slice_end)
|
||||||
: up_lengths_{make_tuple(slice_end - slice_begin)},
|
: up_lengths_{make_tuple(slice_end - slice_begin)},
|
||||||
slice_begin_{slice_begin},
|
slice_begin_{slice_begin},
|
||||||
slice_end_{slice_end}
|
slice_end_{slice_end}
|
||||||
@@ -1724,7 +1728,7 @@ struct DynamicSlice
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicSlice, ");
|
printf("Slice, ");
|
||||||
printf("up_lengths_");
|
printf("up_lengths_");
|
||||||
print_multi_index(up_lengths_);
|
print_multi_index(up_lengths_);
|
||||||
printf("slice_begin_ %d", index_t{slice_begin_});
|
printf("slice_begin_ %d", index_t{slice_begin_});
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
|
#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
|
||||||
#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP
|
#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform.hpp"
|
#include "multi_index_transform.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <typename LowLength>
|
template <typename LowLength>
|
||||||
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
|
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
|
||||||
{
|
{
|
||||||
return DynamicPassThrough<LowLength>{low_length};
|
return PassThrough<LowLength>{low_length};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
|
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
|
||||||
@@ -19,47 +19,46 @@ make_pad_transform(const LowLength& low_length,
|
|||||||
const RightPad& right_pad,
|
const RightPad& right_pad,
|
||||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||||
{
|
{
|
||||||
return DynamicPad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{
|
return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
|
||||||
low_length, left_pad, right_pad};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowLength, typename LeftPad, bool SkipIsValidCheck = false>
|
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||||
__host__ __device__ constexpr auto make_left_pad_transform(
|
__host__ __device__ constexpr auto make_left_pad_transform(
|
||||||
const LowLength& low_length,
|
const LowLength& low_length,
|
||||||
const LeftPad& left_pad,
|
const LeftPadLength& left_pad,
|
||||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||||
{
|
{
|
||||||
return DynamicLeftPad<LowLength, LeftPad, SkipIsValidCheck>{low_length, left_pad};
|
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowLength, typename RightPad, bool SkipIsValidCheck>
|
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||||
__host__ __device__ constexpr auto make_right_pad_transform(
|
__host__ __device__ constexpr auto make_right_pad_transform(
|
||||||
const LowLength& low_length,
|
const LowLength& low_length,
|
||||||
const RightPad& right_pad,
|
const RightPadLength& right_pad,
|
||||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||||
{
|
{
|
||||||
return DynamicRightPad<LowLength, RightPad, SkipIsValidCheck>{low_length, right_pad};
|
return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename UpLengths,
|
template <typename UpLengths,
|
||||||
typename Coefficients,
|
typename Coefficients,
|
||||||
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||||
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
|
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
|
||||||
const Coefficients& coefficients)
|
const Coefficients& coefficients)
|
||||||
{
|
{
|
||||||
return DynamicEmbed<UpLengths, Coefficients>{up_lengths, coefficients};
|
return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowLengths>
|
template <typename LowLengths>
|
||||||
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
|
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
|
||||||
{
|
{
|
||||||
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||||
return DynamicMerge_v1_carry_check<LowLengths>{low_lengths};
|
return Merge_v1_carry_check<LowLengths>{low_lengths};
|
||||||
#else
|
#else
|
||||||
#if 1
|
#if 1
|
||||||
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
|
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||||
#else
|
#else
|
||||||
return DynamicMerge_v2r2_magic_division<LowLengths>{low_lengths};
|
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
@@ -68,7 +67,7 @@ template <typename LowLengths>
|
|||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
|
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
|
||||||
{
|
{
|
||||||
return DynamicMerge_v2_magic_division<LowLengths>{low_lengths};
|
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
|
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
|
||||||
@@ -76,13 +75,13 @@ __host__ __device__ constexpr auto make_unmerge_transform(
|
|||||||
const UpLengths& up_lengths,
|
const UpLengths& up_lengths,
|
||||||
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
|
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
|
||||||
{
|
{
|
||||||
return DynamicUnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
|
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowerIndex>
|
template <typename LowerIndex>
|
||||||
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
|
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
|
||||||
{
|
{
|
||||||
return DynamicFreeze<LowerIndex>{low_idx};
|
return Freeze<LowerIndex>{low_idx};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||||
@@ -90,14 +89,14 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
|
|||||||
const SliceBegin& slice_begin,
|
const SliceBegin& slice_begin,
|
||||||
const SliceEnd& slice_end)
|
const SliceEnd& slice_end)
|
||||||
{
|
{
|
||||||
return DynamicSlice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
|
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename VectorSize, typename UpLength>
|
template <typename VectorSize, typename UpLength>
|
||||||
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
|
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
|
||||||
const UpLength& up_length)
|
const UpLength& up_length)
|
||||||
{
|
{
|
||||||
return DynamicVectorize<VectorSize, UpLength>{vector_size, up_length};
|
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
@@ -2,8 +2,8 @@
|
|||||||
#define CK_TENSOR_ADAPTOR_HPP
|
#define CK_TENSOR_ADAPTOR_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
|
|||||||
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
|
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X,
|
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||||
typename... Xs,
|
|
||||||
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
|
||||||
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||||
{
|
{
|
||||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
|
#ifndef CK_TENSOR_DESCRIPTOR_HPP
|
||||||
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP
|
#define CK_TENSOR_DESCRIPTOR_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform.hpp"
|
#include "multi_index_transform.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <index_t NDimHidden, typename VisibleDimensionIds>
|
template <index_t NDimHidden, typename VisibleDimensionIds>
|
||||||
struct DynamicTensorCoordinate;
|
struct TensorCoordinate;
|
||||||
|
|
||||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||||
struct DynamicTensorCoordinateIterator;
|
struct TensorCoordinateStep;
|
||||||
|
|
||||||
// Transforms: Tuple<transforms...>
|
// Transforms: Tuple<transforms...>
|
||||||
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
|
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
|
||||||
@@ -21,7 +21,7 @@ template <typename Transforms,
|
|||||||
typename UpperDimensionIdss,
|
typename UpperDimensionIdss,
|
||||||
typename VisibleDimensionIds,
|
typename VisibleDimensionIds,
|
||||||
typename ElementSpaceSize>
|
typename ElementSpaceSize>
|
||||||
struct DynamicTensorDescriptor
|
struct TensorDescriptor
|
||||||
{
|
{
|
||||||
// TODO make these private
|
// TODO make these private
|
||||||
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
|
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
|
||||||
@@ -105,16 +105,16 @@ struct DynamicTensorDescriptor
|
|||||||
|
|
||||||
using VisibleIndex = MultiIndex<ndim_visible_>;
|
using VisibleIndex = MultiIndex<ndim_visible_>;
|
||||||
using HiddenIndex = MultiIndex<ndim_hidden_>;
|
using HiddenIndex = MultiIndex<ndim_hidden_>;
|
||||||
using Coordinate = DynamicTensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
|
using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
|
||||||
|
|
||||||
// may be index_t or Number<>
|
// may be index_t or Number<>
|
||||||
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
|
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
__host__ __device__ constexpr DynamicTensorDescriptor() = default;
|
__host__ __device__ constexpr TensorDescriptor() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms,
|
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
|
||||||
ElementSpaceSize element_space_size)
|
ElementSpaceSize element_space_size)
|
||||||
: transforms_{transforms},
|
: transforms_{transforms},
|
||||||
element_size_{InitializeElementSize(transforms)},
|
element_size_{InitializeElementSize(transforms)},
|
||||||
element_space_size_{element_space_size}
|
element_space_size_{element_space_size}
|
||||||
@@ -159,7 +159,7 @@ struct DynamicTensorDescriptor
|
|||||||
{
|
{
|
||||||
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
|
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate(*this, idx).GetOffset();
|
return make_tensor_coordinate(*this, idx).GetOffset();
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO make these private
|
// TODO make these private
|
||||||
@@ -196,7 +196,7 @@ struct DynamicTensorDescriptor
|
|||||||
__host__ __device__ void Print() const
|
__host__ __device__ void Print() const
|
||||||
{
|
{
|
||||||
printf("{");
|
printf("{");
|
||||||
printf("DynamicTensorDescriptor, ");
|
printf("TensorDescriptor, ");
|
||||||
static_for<0, ntransform_, 1>{}([&](auto i) {
|
static_for<0, ntransform_, 1>{}([&](auto i) {
|
||||||
printf("transforms: ");
|
printf("transforms: ");
|
||||||
transforms_[i].Print();
|
transforms_[i].Print();
|
||||||
@@ -217,7 +217,7 @@ struct DynamicTensorDescriptor
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <index_t NDimHidden, typename VisibleDimensionIds>
|
template <index_t NDimHidden, typename VisibleDimensionIds>
|
||||||
struct DynamicTensorCoordinate
|
struct TensorCoordinate
|
||||||
{
|
{
|
||||||
// TODO make these private
|
// TODO make these private
|
||||||
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
|
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
|
||||||
@@ -226,9 +226,9 @@ struct DynamicTensorCoordinate
|
|||||||
using VisibleIndex = MultiIndex<ndim_visible_>;
|
using VisibleIndex = MultiIndex<ndim_visible_>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
__host__ __device__ constexpr DynamicTensorCoordinate() = default;
|
__host__ __device__ constexpr TensorCoordinate() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden)
|
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
|
||||||
: idx_hidden_{idx_hidden}
|
: idx_hidden_{idx_hidden}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -252,16 +252,16 @@ struct DynamicTensorCoordinate
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||||
struct DynamicTensorCoordinateIterator
|
struct TensorCoordinateStep
|
||||||
{
|
{
|
||||||
// TODO make these private
|
// TODO make these private
|
||||||
using VisibleIndex = MultiIndex<NDimVisible>;
|
using VisibleIndex = MultiIndex<NDimVisible>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
__host__ __device__ constexpr DynamicTensorCoordinateIterator() = default;
|
__host__ __device__ constexpr TensorCoordinateStep() = default;
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicTensorCoordinateIterator(
|
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
|
||||||
const VisibleIndex& idx_diff_visible, const MultiIndex<NTransform>& do_transforms)
|
const MultiIndex<NTransform>& do_transforms)
|
||||||
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
|
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -283,7 +283,7 @@ struct DynamicTensorCoordinateIterator
|
|||||||
|
|
||||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||||
// doesn't have constructor, and to put it outside the scope where it is used
|
// doesn't have constructor, and to put it outside the scope where it is used
|
||||||
// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function
|
// (transform_tensor_descriptor) because template cannot be defined inside a function
|
||||||
// template
|
// template
|
||||||
template <typename NewTransforms>
|
template <typename NewTransforms>
|
||||||
struct lambda_get_up_dim_num
|
struct lambda_get_up_dim_num
|
||||||
@@ -301,10 +301,10 @@ template <typename OldTensorDescriptor,
|
|||||||
typename NewLowerDimensionOldVisibleIdss,
|
typename NewLowerDimensionOldVisibleIdss,
|
||||||
typename NewUpperDimensionNewVisibleIdss>
|
typename NewUpperDimensionNewVisibleIdss>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||||
const NewTransforms& new_transforms,
|
const NewTransforms& new_transforms,
|
||||||
NewLowerDimensionOldVisibleIdss,
|
NewLowerDimensionOldVisibleIdss,
|
||||||
NewUpperDimensionNewVisibleIdss)
|
NewUpperDimensionNewVisibleIdss)
|
||||||
{
|
{
|
||||||
// sanity check
|
// sanity check
|
||||||
{
|
{
|
||||||
@@ -376,17 +376,17 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
|||||||
|
|
||||||
const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
|
const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
|
||||||
|
|
||||||
return DynamicTensorDescriptor<remove_cv_t<decltype(all_transforms)>,
|
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
|
||||||
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
|
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
|
||||||
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
|
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
|
||||||
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
|
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
|
||||||
remove_cv_t<decltype(element_space_size)>>{all_transforms,
|
remove_cv_t<decltype(element_space_size)>>{all_transforms,
|
||||||
element_space_size};
|
element_space_size};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorDesc, typename VisibleIndex>
|
template <typename TensorDesc, typename VisibleIndex>
|
||||||
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc,
|
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||||
const VisibleIndex& idx_visible)
|
const VisibleIndex& idx_visible)
|
||||||
{
|
{
|
||||||
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
||||||
"wrong! # of dimension inconsistent");
|
"wrong! # of dimension inconsistent");
|
||||||
@@ -416,14 +416,15 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe
|
|||||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||||
});
|
});
|
||||||
|
|
||||||
return DynamicTensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
|
return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLowerIndexHack: Sequence<...>
|
// UpdateLowerIndexHack: Sequence<...>
|
||||||
// HACK: control UpdateLowerIndex
|
// HACK: control UpdateLowerIndex
|
||||||
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
|
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
|
||||||
__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
|
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||||
const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack)
|
const VisibleIndex& idx_diff_visible,
|
||||||
|
UpdateLowerIndexHack)
|
||||||
{
|
{
|
||||||
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
||||||
"wrong! # of dimension inconsistent");
|
"wrong! # of dimension inconsistent");
|
||||||
@@ -470,23 +471,24 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator(
|
|||||||
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
||||||
});
|
});
|
||||||
|
|
||||||
return DynamicTensorCoordinateIterator<ntransform, ndim_visible, UpdateLowerIndexHack>{
|
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
|
||||||
idx_diff_visible, do_transforms};
|
do_transforms};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorDesc, typename VisibleIndex>
|
template <typename TensorDesc, typename VisibleIndex>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||||
make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible)
|
const VisibleIndex& idx_diff_visible)
|
||||||
{
|
{
|
||||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
|
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterator>
|
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
|
||||||
__host__ __device__ constexpr void move_dynamic_tensor_coordinate(
|
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||||
const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
|
TensorCoord& coord,
|
||||||
|
const TensorCoordStep& coord_step)
|
||||||
{
|
{
|
||||||
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
||||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||||
@@ -495,9 +497,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
|
|||||||
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
||||||
|
|
||||||
// initialize visible index diff
|
// initialize visible index diff
|
||||||
set_container_subset(idx_diff_hidden,
|
set_container_subset(
|
||||||
TensorDesc::GetVisibleDimensionIds(),
|
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
|
||||||
coord_iterator.GetVisibleIndexDiff());
|
|
||||||
|
|
||||||
// this is what needs to be updated
|
// this is what needs to be updated
|
||||||
auto& idx_hidden = coord.GetHiddenIndex();
|
auto& idx_hidden = coord.GetHiddenIndex();
|
||||||
@@ -506,13 +507,13 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
|
|||||||
auto idx_hidden_pick_visible =
|
auto idx_hidden_pick_visible =
|
||||||
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
|
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
|
||||||
|
|
||||||
idx_hidden_pick_visible += coord_iterator.GetIndexDiff();
|
idx_hidden_pick_visible += coord_step.GetIndexDiff();
|
||||||
|
|
||||||
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
|
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
|
||||||
|
|
||||||
// update rest of hidden index
|
// update rest of hidden index
|
||||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||||
if(coord_iterator.do_transforms_[itran])
|
if(coord_step.do_transforms_[itran])
|
||||||
{
|
{
|
||||||
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
||||||
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
||||||
@@ -524,8 +525,8 @@ __host__ __device__ constexpr void move_dynamic_tensor_coordinate(
|
|||||||
|
|
||||||
MultiIndex<dims_low.Size()> idx_diff_low;
|
MultiIndex<dims_low.Size()> idx_diff_low;
|
||||||
|
|
||||||
// HACK: control UpdateLowerIndex for DynamicMerge using hack
|
// HACK: control UpdateLowerIndex for Merge using hack
|
||||||
constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran);
|
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
|
||||||
|
|
||||||
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
|
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
|
||||||
|
|
||||||
@@ -585,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc&
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename TensorDesc>
|
template <typename TensorDesc>
|
||||||
using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate(
|
using TensorCoordinate_t = decltype(make_tensor_coordinate(
|
||||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||||
|
|
||||||
template <typename TensorDesc>
|
template <typename TensorDesc>
|
||||||
using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator(
|
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
|
||||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
|
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||||
#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP
|
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_multi_index_transform_helper.hpp"
|
#include "multi_index_transform_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -37,10 +37,9 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
|
|||||||
|
|
||||||
template <typename... Lengths,
|
template <typename... Lengths,
|
||||||
typename... Strides,
|
typename... Strides,
|
||||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
|
||||||
make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
|
const Tuple<Strides...>& strides)
|
||||||
const Tuple<Strides...>& strides)
|
|
||||||
{
|
{
|
||||||
constexpr index_t N = sizeof...(Lengths);
|
constexpr index_t N = sizeof...(Lengths);
|
||||||
|
|
||||||
@@ -75,12 +74,12 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
|
|||||||
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
|
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
|
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||||
remove_cv_t<decltype(element_space_size)>>{transforms,
|
remove_cv_t<decltype(element_space_size)>>{transforms,
|
||||||
element_space_size};
|
element_space_size};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lengths... can be:
|
// Lengths... can be:
|
||||||
@@ -88,7 +87,7 @@ make_dynamic_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
|
|||||||
// 2) Number<>, which is known at compile-time
|
// 2) Number<>, which is known at compile-time
|
||||||
template <typename... Lengths>
|
template <typename... Lengths>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
|
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||||
{
|
{
|
||||||
constexpr index_t N = sizeof...(Lengths);
|
constexpr index_t N = sizeof...(Lengths);
|
||||||
|
|
||||||
@@ -103,17 +102,17 @@ make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple<Lengths...>& lengths)
|
|||||||
|
|
||||||
const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
|
const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{});
|
||||||
|
|
||||||
return DynamicTensorDescriptor<remove_cv_t<decltype(transforms)>,
|
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||||
remove_cv_t<decltype(element_space_size)>>{transforms,
|
remove_cv_t<decltype(element_space_size)>>{transforms,
|
||||||
element_space_size};
|
element_space_size};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Lengths, typename Align>
|
template <typename... Lengths, typename Align>
|
||||||
__host__ __device__ constexpr auto
|
__host__ __device__ constexpr auto
|
||||||
make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
|
make_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths, Align align)
|
||||||
{
|
{
|
||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
|
|
||||||
@@ -143,7 +142,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple<Lengths...>& lengths
|
|||||||
},
|
},
|
||||||
Number<N>{});
|
Number<N>{});
|
||||||
|
|
||||||
return make_dynamic_naive_tensor_descriptor_v2(lengths, strides);
|
return make_naive_tensor_descriptor_v2(lengths, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "tensor_adaptor.hpp"
|
#include "tensor_adaptor.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_contraction_dlops.hpp"
|
#include "threadwise_contraction_dlops.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
@@ -22,24 +22,24 @@ namespace ck {
|
|||||||
// 2. CThreadBuffer is StaticBuffer
|
// 2. CThreadBuffer is StaticBuffer
|
||||||
// Also assume:
|
// Also assume:
|
||||||
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
|
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
|
||||||
template <index_t BlockSize,
|
template <
|
||||||
typename FloatA,
|
index_t BlockSize,
|
||||||
typename FloatB,
|
typename FloatA,
|
||||||
typename FloatC,
|
typename FloatB,
|
||||||
typename AKMBlockDesc,
|
typename FloatC,
|
||||||
typename BKNBlockDesc,
|
typename AKMBlockDesc,
|
||||||
index_t M1PerThreadM11,
|
typename BKNBlockDesc,
|
||||||
index_t N1PerThreadN11,
|
index_t M1PerThreadM11,
|
||||||
index_t KPerThread,
|
index_t N1PerThreadN11,
|
||||||
index_t M1N1ThreadClusterM100,
|
index_t KPerThread,
|
||||||
index_t M1N1ThreadClusterN100,
|
index_t M1N1ThreadClusterM100,
|
||||||
index_t M1N1ThreadClusterM101,
|
index_t M1N1ThreadClusterN100,
|
||||||
index_t M1N1ThreadClusterN101,
|
index_t M1N1ThreadClusterM101,
|
||||||
index_t AThreadCopyScalarPerVector_M11,
|
index_t M1N1ThreadClusterN101,
|
||||||
index_t BThreadCopyScalarPerVector_N11,
|
index_t AThreadCopyScalarPerVector_M11,
|
||||||
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
|
index_t BThreadCopyScalarPerVector_N11,
|
||||||
BKNBlockDesc::IsKnownAtCompileTime(),
|
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
||||||
{
|
{
|
||||||
using AIndex = MultiIndex<3>;
|
using AIndex = MultiIndex<3>;
|
||||||
@@ -71,9 +71,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
|||||||
static constexpr index_t N0 = N / N1;
|
static constexpr index_t N0 = N / N1;
|
||||||
|
|
||||||
__host__ __device__ static constexpr auto
|
__host__ __device__ static constexpr auto
|
||||||
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
|
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
|
||||||
{
|
{
|
||||||
const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
|
const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
|
||||||
AKMBlockDesc{},
|
AKMBlockDesc{},
|
||||||
make_tuple(make_pass_through_transform(Number<K>{}),
|
make_tuple(make_pass_through_transform(Number<K>{}),
|
||||||
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
|
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
|
||||||
@@ -84,9 +84,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
|||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr auto
|
__host__ __device__ static constexpr auto
|
||||||
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc)
|
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
|
||||||
{
|
{
|
||||||
const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
|
const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
|
||||||
BKNBlockDesc{},
|
BKNBlockDesc{},
|
||||||
make_tuple(make_pass_through_transform(Number<K>{}),
|
make_tuple(make_pass_through_transform(Number<K>{}),
|
||||||
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
|
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
|
||||||
@@ -194,7 +194,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
|||||||
typename ABlockBuffer,
|
typename ABlockBuffer,
|
||||||
typename BBlockBuffer,
|
typename BBlockBuffer,
|
||||||
typename CThreadBuffer>
|
typename CThreadBuffer>
|
||||||
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
|
__device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
|
||||||
const ABlockBuffer& a_block_buf,
|
const ABlockBuffer& a_block_buf,
|
||||||
const BBlockBuffer& b_block_buf,
|
const BBlockBuffer& b_block_buf,
|
||||||
CThreadBuffer& c_thread_buf) const
|
CThreadBuffer& c_thread_buf) const
|
||||||
@@ -357,34 +357,32 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// A[K, M0, M1]
|
// A[K, M0, M1]
|
||||||
static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||||
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
|
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
|
||||||
|
|
||||||
// B[K, N0, N1]
|
// B[K, N0, N1]
|
||||||
static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||||
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
|
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
|
||||||
|
|
||||||
using AThreadCopy =
|
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
|
FloatA,
|
||||||
FloatA,
|
decltype(a_k_m0_m1_block_desc_),
|
||||||
decltype(a_k_m0_m1_block_desc_),
|
decltype(a_k_m0_m1_thread_desc_),
|
||||||
decltype(a_k_m0_m1_thread_desc_),
|
Sequence<KPerThread, 1, M1PerThreadM11>,
|
||||||
Sequence<KPerThread, 1, M1PerThreadM11>,
|
Sequence<0, 1, 2>,
|
||||||
Sequence<0, 1, 2>,
|
2,
|
||||||
2,
|
AThreadCopyScalarPerVector_M11,
|
||||||
AThreadCopyScalarPerVector_M11,
|
1>;
|
||||||
1>;
|
|
||||||
|
|
||||||
using BThreadCopy =
|
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
|
FloatB,
|
||||||
FloatB,
|
decltype(b_k_n0_n1_block_desc_),
|
||||||
decltype(b_k_n0_n1_block_desc_),
|
decltype(b_k_n0_n1_thread_desc_),
|
||||||
decltype(b_k_n0_n1_thread_desc_),
|
Sequence<KPerThread, 1, N1PerThreadN11>,
|
||||||
Sequence<KPerThread, 1, N1PerThreadN11>,
|
Sequence<0, 1, 2>,
|
||||||
Sequence<0, 1, 2>,
|
2,
|
||||||
2,
|
BThreadCopyScalarPerVector_N11,
|
||||||
BThreadCopyScalarPerVector_N11,
|
1>;
|
||||||
1>;
|
|
||||||
|
|
||||||
CIndex c_thread_origin_data_idx_;
|
CIndex c_thread_origin_data_idx_;
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "tensor_adaptor.hpp"
|
#include "tensor_adaptor.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
|
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||||
#include "threadwise_contraction_dlops.hpp"
|
#include "threadwise_contraction_dlops.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
@@ -38,9 +38,9 @@ template <index_t BlockSize,
|
|||||||
// BM10BN10ThreadClusterBN101, ...>
|
// BM10BN10ThreadClusterBN101, ...>
|
||||||
index_t AThreadCopyScalarPerVector_BM11,
|
index_t AThreadCopyScalarPerVector_BM11,
|
||||||
index_t BThreadCopyScalarPerVector_BN11,
|
index_t BThreadCopyScalarPerVector_BN11,
|
||||||
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||||
{
|
{
|
||||||
using AIndex = MultiIndex<3>;
|
using AIndex = MultiIndex<3>;
|
||||||
@@ -75,7 +75,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
|||||||
__host__ __device__ static constexpr auto
|
__host__ __device__ static constexpr auto
|
||||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
|
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
|
||||||
{
|
{
|
||||||
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
|
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
|
||||||
a_block_desc_bk0_bm_bk1,
|
a_block_desc_bk0_bm_bk1,
|
||||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||||
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
|
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
|
||||||
@@ -89,7 +89,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
|||||||
__host__ __device__ static constexpr auto
|
__host__ __device__ static constexpr auto
|
||||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
|
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
|
||||||
{
|
{
|
||||||
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
|
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
|
||||||
b_block_desc_bk0_bn_bk1,
|
b_block_desc_bk0_bn_bk1,
|
||||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||||
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
|
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
|
||||||
@@ -372,15 +372,15 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
|||||||
private:
|
private:
|
||||||
// A[BK0, BM0, BM1, BK1]
|
// A[BK0, BM0, BM1, BK1]
|
||||||
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
|
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
|
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
|
||||||
|
|
||||||
// B[BK0, BN0, BN1, BK1]
|
// B[BK0, BN0, BN1, BK1]
|
||||||
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
|
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
|
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
|
||||||
|
|
||||||
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
|
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
|
||||||
FloatA,
|
FloatA,
|
||||||
FloatA,
|
FloatA,
|
||||||
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
|
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
|
||||||
@@ -390,7 +390,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
|
|||||||
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
|
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
|
||||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||||
|
|
||||||
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
|
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
|
||||||
FloatB,
|
FloatB,
|
||||||
FloatB,
|
FloatB,
|
||||||
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
|
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
|
||||||
|
|||||||
@@ -31,25 +31,24 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
|||||||
// HACK: fix this @Jing Zhang
|
// HACK: fix this @Jing Zhang
|
||||||
static constexpr index_t KPerThreadSubC = 4;
|
static constexpr index_t KPerThreadSubC = 4;
|
||||||
|
|
||||||
static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
|
||||||
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
|
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
|
||||||
|
|
||||||
static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||||
|
|
||||||
static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||||
|
|
||||||
using AThreadCopy =
|
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
|
FloatA,
|
||||||
FloatA,
|
BlockMatrixA,
|
||||||
BlockMatrixA,
|
decltype(a_thread_mtx_),
|
||||||
decltype(a_thread_mtx_),
|
Sequence<EPerThreadLoop, KPerThreadSubC>,
|
||||||
Sequence<EPerThreadLoop, KPerThreadSubC>,
|
Sequence<0, 1>,
|
||||||
Sequence<0, 1>,
|
1,
|
||||||
1,
|
ThreadGemmADataPerRead_K,
|
||||||
ThreadGemmADataPerRead_K,
|
1>;
|
||||||
1>;
|
|
||||||
|
|
||||||
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
|
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
|
||||||
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
|
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
|
||||||
@@ -69,7 +68,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
|||||||
"wrong! K dimension not consistent\n");
|
"wrong! K dimension not consistent\n");
|
||||||
|
|
||||||
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
|
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
|
||||||
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
|
|
||||||
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
|
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
|
||||||
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
|
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
|
||||||
|
|
||||||
@@ -121,9 +119,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
|||||||
"wrong! inconsistent type");
|
"wrong! inconsistent type");
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
constexpr auto I0 = Number<0>{};
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||||
|
|
||||||
@@ -138,7 +133,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
|||||||
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
||||||
|
|
||||||
// thread A buffer for GEMM
|
// thread A buffer for GEMM
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
|
||||||
a_thread_buf;
|
a_thread_buf;
|
||||||
|
|
||||||
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
|
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
#include "xdlops_gemm.hpp"
|
#include "xdlops_gemm.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
@@ -52,7 +52,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
|||||||
const index_t waveId = thread_id / WaveSize;
|
const index_t waveId = thread_id / WaveSize;
|
||||||
const index_t laneId = thread_id % WaveSize;
|
const index_t laneId = thread_id % WaveSize;
|
||||||
const index_t waveId_m = waveId / NWaves;
|
const index_t waveId_m = waveId / NWaves;
|
||||||
const index_t waveId_n = waveId % NWaves;
|
|
||||||
|
|
||||||
if constexpr(xdlops_gemm.IsKReduction)
|
if constexpr(xdlops_gemm.IsKReduction)
|
||||||
{
|
{
|
||||||
@@ -73,7 +72,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
|||||||
const index_t thread_id = get_thread_local_1d_id();
|
const index_t thread_id = get_thread_local_1d_id();
|
||||||
const index_t waveId = thread_id / WaveSize;
|
const index_t waveId = thread_id / WaveSize;
|
||||||
const index_t laneId = thread_id % WaveSize;
|
const index_t laneId = thread_id % WaveSize;
|
||||||
const index_t waveId_m = waveId / NWaves;
|
|
||||||
const index_t waveId_n = waveId % NWaves;
|
const index_t waveId_n = waveId % NWaves;
|
||||||
|
|
||||||
if constexpr(xdlops_gemm.IsKReduction)
|
if constexpr(xdlops_gemm.IsKReduction)
|
||||||
@@ -193,35 +191,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// A[K, M]
|
// A[K, M]
|
||||||
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto a_thread_desc_ =
|
||||||
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||||
|
|
||||||
// B[K, N]
|
// B[K, N]
|
||||||
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto b_thread_desc_ =
|
||||||
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||||
|
|
||||||
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto c_thread_desc_ =
|
||||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||||
|
|
||||||
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
|
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
ABlockDesc,
|
ABlockDesc,
|
||||||
decltype(a_thread_desc_),
|
decltype(a_thread_desc_),
|
||||||
Sequence<1, MRepeat, 1, K1>,
|
Sequence<1, MRepeat, 1, K1>,
|
||||||
Sequence<0, 1, 2, 3>,
|
Sequence<0, 1, 2, 3>,
|
||||||
3,
|
3,
|
||||||
K1,
|
K1,
|
||||||
1>;
|
1>;
|
||||||
|
|
||||||
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
|
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
BBlockDesc,
|
BBlockDesc,
|
||||||
decltype(b_thread_desc_),
|
decltype(b_thread_desc_),
|
||||||
Sequence<1, NRepeat, 1, K1>,
|
Sequence<1, NRepeat, 1, K1>,
|
||||||
Sequence<0, 1, 2, 3>,
|
Sequence<0, 1, 2, 3>,
|
||||||
3,
|
3,
|
||||||
K1,
|
K1,
|
||||||
1>;
|
1>;
|
||||||
|
|
||||||
AThreadCopy a_thread_copy_;
|
AThreadCopy a_thread_copy_;
|
||||||
BThreadCopy b_thread_copy_;
|
BThreadCopy b_thread_copy_;
|
||||||
@@ -272,7 +270,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
|||||||
const index_t waveId = thread_id / WaveSize;
|
const index_t waveId = thread_id / WaveSize;
|
||||||
const index_t laneId = thread_id % WaveSize;
|
const index_t laneId = thread_id % WaveSize;
|
||||||
const index_t waveId_m = waveId / NWaves;
|
const index_t waveId_m = waveId / NWaves;
|
||||||
const index_t waveId_n = waveId % NWaves;
|
|
||||||
|
|
||||||
if constexpr(xdlops_gemm.IsKReduction)
|
if constexpr(xdlops_gemm.IsKReduction)
|
||||||
{
|
{
|
||||||
@@ -293,7 +290,6 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
|||||||
const index_t thread_id = get_thread_local_1d_id();
|
const index_t thread_id = get_thread_local_1d_id();
|
||||||
const index_t waveId = thread_id / WaveSize;
|
const index_t waveId = thread_id / WaveSize;
|
||||||
const index_t laneId = thread_id % WaveSize;
|
const index_t laneId = thread_id % WaveSize;
|
||||||
const index_t waveId_m = waveId / NWaves;
|
|
||||||
const index_t waveId_n = waveId % NWaves;
|
const index_t waveId_n = waveId % NWaves;
|
||||||
|
|
||||||
if constexpr(xdlops_gemm.IsKReduction)
|
if constexpr(xdlops_gemm.IsKReduction)
|
||||||
@@ -490,35 +486,35 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// A[K, M]
|
// A[K, M]
|
||||||
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto a_thread_desc_ =
|
||||||
make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||||
|
|
||||||
// B[K, N]
|
// B[K, N]
|
||||||
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto b_thread_desc_ =
|
||||||
make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||||
|
|
||||||
static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
|
static constexpr auto c_thread_desc_ =
|
||||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||||
|
|
||||||
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
|
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
ABlockDesc,
|
ABlockDesc,
|
||||||
decltype(a_thread_desc_),
|
decltype(a_thread_desc_),
|
||||||
Sequence<1, 1, 1, K1>,
|
Sequence<1, 1, 1, K1>,
|
||||||
Sequence<0, 1, 2, 3>,
|
Sequence<0, 1, 2, 3>,
|
||||||
3,
|
3,
|
||||||
1, // K1,
|
1, // K1,
|
||||||
1>;
|
1>;
|
||||||
|
|
||||||
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatAB,
|
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
BBlockDesc,
|
BBlockDesc,
|
||||||
decltype(b_thread_desc_),
|
decltype(b_thread_desc_),
|
||||||
Sequence<1, 1, 1, K1>,
|
Sequence<1, 1, 1, K1>,
|
||||||
Sequence<0, 1, 2, 3>,
|
Sequence<0, 1, 2, 3>,
|
||||||
3,
|
3,
|
||||||
1, // K1,
|
1, // K1,
|
||||||
1>;
|
1>;
|
||||||
|
|
||||||
AThreadCopy a_thread_copy_;
|
AThreadCopy a_thread_copy_;
|
||||||
BThreadCopy b_thread_copy_;
|
BThreadCopy b_thread_copy_;
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
|
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||||
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
|
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "cluster_descriptor.hpp"
|
#include "cluster_descriptor.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
// this version does following things to avoid scratch memory issue
|
// this version does following things to avoid scratch memory issue
|
||||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||||
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||||
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||||
template <index_t BlockSize,
|
template <index_t BlockSize,
|
||||||
InMemoryDataOperationEnum_t DstInMemOp,
|
InMemoryDataOperationEnum_t DstInMemOp,
|
||||||
typename BlockSliceLengths,
|
typename BlockSliceLengths,
|
||||||
@@ -33,16 +33,16 @@ template <index_t BlockSize,
|
|||||||
index_t DstScalarStrideInVector,
|
index_t DstScalarStrideInVector,
|
||||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||||
struct BlockwiseDynamicTensorSliceTransfer_v4
|
struct BlockwiseTensorSliceTransfer_v4
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||||
|
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc,
|
__device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
|
||||||
const Index& src_block_slice_origin,
|
const Index& src_block_slice_origin,
|
||||||
const DstDesc& dst_desc,
|
const DstDesc& dst_desc,
|
||||||
const Index& dst_block_slice_origin)
|
const Index& dst_block_slice_origin)
|
||||||
: threadwise_transfer_(
|
: threadwise_transfer_(
|
||||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||||
|
|
||||||
@@ -77,15 +77,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,18 +117,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
|
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& step,
|
const Index& step,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.MoveSrcSliceWindow(
|
threadwise_transfer_.MoveSrcSliceWindow(
|
||||||
src_desc, step, src_move_slice_window_iterator_hack);
|
src_desc, step, src_move_slice_window_step_hack);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,22 +146,22 @@ struct BlockwiseDynamicTensorSliceTransfer_v4
|
|||||||
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||||
|
|
||||||
using ThreadwiseTransfer =
|
using ThreadwiseTransfer =
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v3<ThreadSliceLengths,
|
ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
|
||||||
DstInMemOp,
|
DstInMemOp,
|
||||||
SrcData,
|
SrcData,
|
||||||
DstData,
|
DstData,
|
||||||
SrcDesc,
|
SrcDesc,
|
||||||
DstDesc,
|
DstDesc,
|
||||||
SrcDimAccessOrder,
|
SrcDimAccessOrder,
|
||||||
DstDimAccessOrder,
|
DstDimAccessOrder,
|
||||||
SrcVectorDim,
|
SrcVectorDim,
|
||||||
DstVectorDim,
|
DstVectorDim,
|
||||||
SrcScalarPerVector,
|
SrcScalarPerVector,
|
||||||
DstScalarPerVector,
|
DstScalarPerVector,
|
||||||
SrcScalarStrideInVector,
|
SrcScalarStrideInVector,
|
||||||
DstScalarStrideInVector,
|
DstScalarStrideInVector,
|
||||||
ThreadTransferSrcResetCoordinateAfterRun,
|
ThreadTransferSrcResetCoordinateAfterRun,
|
||||||
ThreadTransferDstResetCoordinateAfterRun>;
|
ThreadTransferDstResetCoordinateAfterRun>;
|
||||||
|
|
||||||
ThreadwiseTransfer threadwise_transfer_;
|
ThreadwiseTransfer threadwise_transfer_;
|
||||||
};
|
};
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
|
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||||
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
|
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "cluster_descriptor.hpp"
|
#include "cluster_descriptor.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
|
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
// this version does following things to avoid scratch memory issue
|
// this version does following things to avoid scratch memory issue
|
||||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||||
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||||
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||||
template <index_t BlockSize,
|
template <index_t BlockSize,
|
||||||
InMemoryDataOperationEnum_t DstInMemOp,
|
InMemoryDataOperationEnum_t DstInMemOp,
|
||||||
typename BlockSliceLengths,
|
typename BlockSliceLengths,
|
||||||
@@ -31,17 +31,16 @@ template <index_t BlockSize,
|
|||||||
typename DstVectorTensorContiguousDimOrder,
|
typename DstVectorTensorContiguousDimOrder,
|
||||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||||
struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
struct BlockwiseTensorSliceTransfer_v4r1
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||||
|
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
__device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(
|
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
|
||||||
const SrcDesc& src_desc,
|
const Index& src_block_slice_origin,
|
||||||
const Index& src_block_slice_origin,
|
const DstDesc& dst_desc,
|
||||||
const DstDesc& dst_desc,
|
const Index& dst_block_slice_origin)
|
||||||
const Index& dst_block_slice_origin)
|
|
||||||
: threadwise_transfer_(
|
: threadwise_transfer_(
|
||||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||||
|
|
||||||
@@ -76,15 +75,14 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
|
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,18 +105,18 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SrcMoveSliceWindowIteratorHack to control index calculation move slice window
|
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& step,
|
const Index& step,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||||
{
|
{
|
||||||
threadwise_transfer_.MoveSrcSliceWindow(
|
threadwise_transfer_.MoveSrcSliceWindow(
|
||||||
src_desc, step, src_move_slice_window_iterator_hack);
|
src_desc, step, src_move_slice_window_step_hack);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,20 +134,20 @@ struct BlockwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||||
|
|
||||||
using ThreadwiseTransfer =
|
using ThreadwiseTransfer =
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v3r1<ThreadSliceLengths,
|
ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
|
||||||
DstInMemOp,
|
DstInMemOp,
|
||||||
SrcData,
|
SrcData,
|
||||||
DstData,
|
DstData,
|
||||||
SrcDesc,
|
SrcDesc,
|
||||||
DstDesc,
|
DstDesc,
|
||||||
SrcDimAccessOrder,
|
SrcDimAccessOrder,
|
||||||
DstDimAccessOrder,
|
DstDimAccessOrder,
|
||||||
SrcVectorTensorLengths,
|
SrcVectorTensorLengths,
|
||||||
DstVectorTensorLengths,
|
DstVectorTensorLengths,
|
||||||
SrcVectorTensorContiguousDimOrder,
|
SrcVectorTensorContiguousDimOrder,
|
||||||
DstVectorTensorContiguousDimOrder,
|
DstVectorTensorContiguousDimOrder,
|
||||||
ThreadTransferSrcResetCoordinateAfterRun,
|
ThreadTransferSrcResetCoordinateAfterRun,
|
||||||
ThreadTransferDstResetCoordinateAfterRun>;
|
ThreadTransferDstResetCoordinateAfterRun>;
|
||||||
|
|
||||||
ThreadwiseTransfer threadwise_transfer_;
|
ThreadwiseTransfer threadwise_transfer_;
|
||||||
};
|
};
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
|
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
|
||||||
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
|
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform_helper.hpp"
|
#include "multi_index_transform_helper.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||||
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
|
#include "blockwise_tensor_slice_transfer_v2.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
#include "threadwise_tensor_slice_set.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -25,7 +25,7 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_contraction_dlops_v1r2(
|
kernel_contraction_dlops_v1r2(
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
@@ -84,12 +84,12 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
static constexpr auto I1 = Number<1>{};
|
static constexpr auto I1 = Number<1>{};
|
||||||
@@ -110,17 +110,15 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
|
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
max_lds_align);
|
||||||
max_lds_align);
|
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
|
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
max_lds_align);
|
||||||
max_lds_align);
|
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||||
@@ -201,7 +199,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
const auto GM11 = Number<GM1PerBlockGM11>{};
|
const auto GM11 = Number<GM1PerBlockGM11>{};
|
||||||
const auto GM10 = GM1 / GM11;
|
const auto GM10 = GM1 / GM11;
|
||||||
|
|
||||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
|
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
|
||||||
a_grid_desc_gk0_gm0_gm1_gk1,
|
a_grid_desc_gk0_gm0_gm1_gk1,
|
||||||
make_tuple(make_pass_through_transform(GK0),
|
make_tuple(make_pass_through_transform(GK0),
|
||||||
make_pass_through_transform(GM0),
|
make_pass_through_transform(GM0),
|
||||||
@@ -222,7 +220,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
const auto GN11 = Number<GN1PerBlockGN11>{};
|
const auto GN11 = Number<GN1PerBlockGN11>{};
|
||||||
const auto GN10 = GN1 / GN11;
|
const auto GN10 = GN1 / GN11;
|
||||||
|
|
||||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
|
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
|
||||||
b_grid_desc_gk0_gn0_gn1_gk1,
|
b_grid_desc_gk0_gn0_gn1_gk1,
|
||||||
make_tuple(make_pass_through_transform(GK0),
|
make_tuple(make_pass_through_transform(GK0),
|
||||||
make_pass_through_transform(GN0),
|
make_pass_through_transform(GN0),
|
||||||
@@ -259,7 +257,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
constexpr auto BM0 = BM / BM1;
|
constexpr auto BM0 = BM / BM1;
|
||||||
constexpr auto BN0 = BN / BN1;
|
constexpr auto BN0 = BN / BN1;
|
||||||
|
|
||||||
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
|
||||||
c_grid_desc_gm0_gm1_gn0_gn1,
|
c_grid_desc_gm0_gm1_gn0_gn1,
|
||||||
make_tuple(make_pass_through_transform(GM0),
|
make_tuple(make_pass_through_transform(GM0),
|
||||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||||
@@ -268,7 +266,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||||
|
|
||||||
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
|
||||||
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
|
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(GM10),
|
make_tuple(make_pass_through_transform(GM10),
|
||||||
make_merge_transform(make_tuple(GM0, GM11)),
|
make_merge_transform(make_tuple(GM0, GM11)),
|
||||||
@@ -277,7 +275,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
|
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
|
||||||
c_gm10_bm_gn10_bn_grid_desc,
|
c_gm10_bm_gn10_bn_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(GM10),
|
make_tuple(make_pass_through_transform(GM10),
|
||||||
make_unmerge_transform(make_tuple(BM0, BM1)),
|
make_unmerge_transform(make_tuple(BM0, BM1)),
|
||||||
@@ -356,26 +354,24 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
|
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
max_lds_align);
|
||||||
max_lds_align);
|
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
|
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_dynamic_naive_tensor_descriptor_aligned_v2(
|
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
max_lds_align);
|
||||||
max_lds_align);
|
|
||||||
|
|
||||||
// A matrix in LDS memory for blockwise GEMM
|
// A matrix in LDS memory for blockwise GEMM
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
|
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory for blockwise GEMM
|
// B matrix in LDS memory for blockwise GEMM
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
|
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
|
||||||
|
|
||||||
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
|
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
|
||||||
@@ -385,7 +381,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
"wrong!");
|
"wrong!");
|
||||||
|
|
||||||
// A matrix blockwise copy
|
// A matrix blockwise copy
|
||||||
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
|
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
|
||||||
@@ -409,7 +405,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
make_multi_index(0, 0, 0, 0, 0));
|
make_multi_index(0, 0, 0, 0, 0));
|
||||||
|
|
||||||
// B matrix blockwise copy
|
// B matrix blockwise copy
|
||||||
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
|
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
|
||||||
@@ -457,9 +453,8 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
|
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
|
||||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||||
|
|
||||||
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
|
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
|
||||||
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
|
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||||
@@ -475,9 +470,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||||
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
|
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||||
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
|
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
|
||||||
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
|
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
|
||||||
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
@@ -501,9 +496,9 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
// LDS double buffer: preload data into LDS
|
// LDS double buffer: preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||||
@@ -520,18 +515,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
// even iteration
|
// even iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||||
@@ -546,18 +541,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
// odd iteration
|
// odd iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -576,18 +571,18 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load last data from device mem
|
// LDS double buffer: load last data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
|
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
|
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -615,7 +610,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
// output: register to global memory
|
// output: register to global memory
|
||||||
{
|
{
|
||||||
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
make_naive_tensor_descriptor_packed(
|
||||||
make_tuple(I1,
|
make_tuple(I1,
|
||||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
|
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
|
||||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
|
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
|
||||||
@@ -627,7 +622,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||||
get_thread_local_1d_id());
|
get_thread_local_1d_id());
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
ThreadwiseTensorSliceTransfer_v1r3<
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
FloatC,
|
FloatC,
|
||||||
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||||
@@ -655,7 +650,7 @@ struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
CGridIteratorHacks{});
|
CGridStepHacks{});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
|
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
|
||||||
#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
|
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform_helper.hpp"
|
#include "multi_index_transform_helper.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "blockwise_gemm_dlops_v2r2.hpp"
|
#include "blockwise_gemm_dlops_v2r2.hpp"
|
||||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
#include "blockwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
#include "threadwise_tensor_slice_set.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_gemm_dlops_v1r2(
|
kernel_gemm_dlops_v1r2(
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
@@ -68,28 +68,27 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_gemm_dlops_v1r2(
|
kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
||||||
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
||||||
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
{
|
||||||
// first cast void CONSTANT void* to void*
|
// first cast void CONSTANT void* to void*
|
||||||
// second cast void* to Desc*
|
// second cast void* to Desc*
|
||||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||||
const auto a_k_m0_m1_grid_desc =
|
const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
|
||||||
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
|
cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
|
||||||
const auto b_k_n0_n1_grid_desc =
|
const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
|
||||||
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
|
cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||||
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
|
||||||
|
|
||||||
constexpr index_t shared_block_size =
|
constexpr index_t shared_block_size =
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
@@ -146,12 +145,12 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
static constexpr auto I1 = Number<1>{};
|
static constexpr auto I1 = Number<1>{};
|
||||||
@@ -167,12 +166,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
@@ -230,7 +229,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
const auto M1 = Number<MPerBlockM1>{};
|
const auto M1 = Number<MPerBlockM1>{};
|
||||||
const auto M0 = M / M1;
|
const auto M0 = M / M1;
|
||||||
|
|
||||||
const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
|
||||||
a_k_m_grid_desc,
|
a_k_m_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
|
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
@@ -248,7 +247,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
const auto N1 = Number<NPerBlockN1>{};
|
const auto N1 = Number<NPerBlockN1>{};
|
||||||
const auto N0 = N / N1;
|
const auto N0 = N / N1;
|
||||||
|
|
||||||
const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
|
||||||
b_k_n_grid_desc,
|
b_k_n_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
|
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
@@ -277,7 +276,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
constexpr auto M10 = M1 / M11;
|
constexpr auto M10 = M1 / M11;
|
||||||
constexpr auto N10 = N1 / N11;
|
constexpr auto N10 = N1 / N11;
|
||||||
|
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||||
c_m_n_grid_desc,
|
c_m_n_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||||
@@ -352,75 +351,75 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
|
||||||
|
|
||||||
// A matrix blockwise copy
|
// A matrix blockwise copy
|
||||||
auto a_blockwise_copy =
|
auto a_blockwise_copy =
|
||||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<KPerBlock, 1, MPerBlockM1>,
|
Sequence<KPerBlock, 1, MPerBlockM1>,
|
||||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
decltype(a_k_m0_m1_grid_desc),
|
decltype(a_k_m0_m1_grid_desc),
|
||||||
decltype(a_k_m0_m1_block_desc),
|
decltype(a_k_m0_m1_block_desc),
|
||||||
ABlockTransferSrcAccessOrder,
|
ABlockTransferSrcAccessOrder,
|
||||||
Sequence<0, 1, 2>,
|
Sequence<0, 1, 2>,
|
||||||
ABlockTransferSrcVectorDim,
|
ABlockTransferSrcVectorDim,
|
||||||
2,
|
2,
|
||||||
ABlockTransferSrcScalarPerVector,
|
ABlockTransferSrcScalarPerVector,
|
||||||
ABlockTransferDstScalarPerVector_M1,
|
ABlockTransferDstScalarPerVector_M1,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
true>(a_k_m0_m1_grid_desc,
|
true>(a_k_m0_m1_grid_desc,
|
||||||
make_multi_index(0, im0, 0),
|
make_multi_index(0, im0, 0),
|
||||||
a_k_m0_m1_block_desc,
|
a_k_m0_m1_block_desc,
|
||||||
make_multi_index(0, 0, 0));
|
make_multi_index(0, 0, 0));
|
||||||
|
|
||||||
// B matrix blockwise copy
|
// B matrix blockwise copy
|
||||||
auto b_blockwise_copy =
|
auto b_blockwise_copy =
|
||||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<KPerBlock, 1, NPerBlockN1>,
|
Sequence<KPerBlock, 1, NPerBlockN1>,
|
||||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
decltype(b_k_n0_n1_grid_desc),
|
decltype(b_k_n0_n1_grid_desc),
|
||||||
decltype(b_k_n0_n1_block_desc),
|
decltype(b_k_n0_n1_block_desc),
|
||||||
BBlockTransferSrcAccessOrder,
|
BBlockTransferSrcAccessOrder,
|
||||||
Sequence<0, 1, 2>,
|
Sequence<0, 1, 2>,
|
||||||
BBlockTransferSrcVectorDim,
|
BBlockTransferSrcVectorDim,
|
||||||
2,
|
2,
|
||||||
BBlockTransferSrcScalarPerVector,
|
BBlockTransferSrcScalarPerVector,
|
||||||
BBlockTransferDstScalarPerVector_N1,
|
BBlockTransferDstScalarPerVector_N1,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
true>(b_k_n0_n1_grid_desc,
|
true>(b_k_n0_n1_grid_desc,
|
||||||
make_multi_index(0, in0, 0),
|
make_multi_index(0, in0, 0),
|
||||||
b_k_n0_n1_block_desc,
|
b_k_n0_n1_block_desc,
|
||||||
make_multi_index(0, 0, 0));
|
make_multi_index(0, 0, 0));
|
||||||
|
|
||||||
// GEMM definition
|
// GEMM definition
|
||||||
// c_mtx += transpose(a_mtx) * b_mtx
|
// c_mtx += transpose(a_mtx) * b_mtx
|
||||||
@@ -447,9 +446,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||||
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
||||||
|
|
||||||
constexpr auto c_m10_m11_n10_n11_thread_desc =
|
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
constexpr auto a_block_aligned_space_size =
|
constexpr auto a_block_aligned_space_size =
|
||||||
@@ -465,9 +463,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
@@ -477,15 +475,15 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||||
|
|
||||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||||
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
|
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
|
||||||
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
|
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
|
||||||
|
|
||||||
// hack to control index calculation when move slice window for A and B matrix for
|
// hack to control index calculation when move slice window for A and B matrix for
|
||||||
// threadwise copy
|
// threadwise copy
|
||||||
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
|
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
|
||||||
AGridMoveSliceWindowIteratorHacks{};
|
AGridMoveSliceWindowStepHacks{};
|
||||||
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
|
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
|
||||||
BGridMoveSliceWindowIteratorHacks{};
|
BGridMoveSliceWindowStepHacks{};
|
||||||
|
|
||||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||||
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||||
@@ -502,9 +500,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
// LDS double buffer: preload data into LDS
|
// LDS double buffer: preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||||
@@ -519,22 +517,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
do
|
do
|
||||||
{
|
{
|
||||||
// even iteration
|
// even iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(
|
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||||
a_k_m0_m1_grid_desc,
|
a_block_slice_copy_step,
|
||||||
a_block_slice_copy_step,
|
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(
|
b_block_slice_copy_step,
|
||||||
b_k_n0_n1_grid_desc,
|
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||||
b_block_slice_copy_step,
|
|
||||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||||
@@ -547,22 +543,20 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||||
|
|
||||||
// odd iteration
|
// odd iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(
|
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||||
a_k_m0_m1_grid_desc,
|
a_block_slice_copy_step,
|
||||||
a_block_slice_copy_step,
|
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(
|
b_block_slice_copy_step,
|
||||||
b_k_n0_n1_grid_desc,
|
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||||
b_block_slice_copy_step,
|
|
||||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -581,18 +575,18 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
a_k_m0_m1_global_move_slice_window_iterator_hack);
|
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
b_k_n0_n1_global_move_slice_window_iterator_hack);
|
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load last data from device mem
|
// LDS double buffer: load last data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(
|
||||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
|
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(
|
||||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
|
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -619,19 +613,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
|
|
||||||
// output: register to global memory
|
// output: register to global memory
|
||||||
{
|
{
|
||||||
constexpr index_t M11 =
|
|
||||||
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
|
|
||||||
constexpr index_t N11 =
|
|
||||||
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
|
|
||||||
|
|
||||||
constexpr index_t M10 = MPerBlockM1 / M11;
|
|
||||||
constexpr index_t N10 = NPerBlockN1 / N11;
|
|
||||||
|
|
||||||
constexpr index_t M111 = M1PerThreadM111;
|
|
||||||
constexpr index_t N111 = N1PerThreadN111;
|
|
||||||
|
|
||||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
make_naive_tensor_descriptor_packed(
|
||||||
make_tuple(I1,
|
make_tuple(I1,
|
||||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
||||||
@@ -642,7 +625,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
|
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
|
||||||
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
|
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
ThreadwiseTensorSliceTransfer_v1r3<
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
FloatC,
|
FloatC,
|
||||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||||
@@ -670,7 +653,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
CGridIteratorHacks{});
|
CGridStepHacks{});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
|
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
|
||||||
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
|
#define CK_GRIDWISE_GEMM_V1R3_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform_helper.hpp"
|
#include "multi_index_transform_helper.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||||
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
|
#include "blockwise_tensor_slice_transfer_v2.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
|
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
#include "threadwise_tensor_slice_set.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_gemm_dlops_v1r3(
|
kernel_gemm_dlops_v1r3(
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
@@ -68,28 +68,27 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_gemm_dlops_v1r3(
|
kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
|
||||||
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
|
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
|
||||||
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
|
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
{
|
||||||
// first cast void CONSTANT void* to void*
|
// first cast void CONSTANT void* to void*
|
||||||
// second cast void* to Desc*
|
// second cast void* to Desc*
|
||||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||||
const auto a_k0_m0_m1_k1_grid_desc =
|
const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
|
||||||
*reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc);
|
cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
|
||||||
const auto b_k0_n0_n1_k1_grid_desc =
|
const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
|
||||||
*reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc);
|
cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||||
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
|
||||||
|
|
||||||
constexpr index_t shared_block_size =
|
constexpr index_t shared_block_size =
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
@@ -142,12 +141,12 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
static constexpr auto I1 = Number<1>{};
|
static constexpr auto I1 = Number<1>{};
|
||||||
@@ -164,12 +163,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
|
|
||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
@@ -191,12 +190,12 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||||
const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);
|
|
||||||
|
|
||||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||||
|
|
||||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||||
|
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||||
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
|
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
|
||||||
}
|
}
|
||||||
@@ -231,13 +230,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
const auto M1 = Number<MPerBlockM1>{};
|
const auto M1 = Number<MPerBlockM1>{};
|
||||||
const auto M0 = M / M1;
|
const auto M0 = M / M1;
|
||||||
|
|
||||||
const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto a_k0_m0_m1_k1_grid_desc =
|
||||||
a_k0_m_k1_grid_desc,
|
transform_tensor_descriptor(a_k0_m_k1_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(K0),
|
make_tuple(make_pass_through_transform(K0),
|
||||||
make_unmerge_transform(make_tuple(M0, M1)),
|
make_unmerge_transform(make_tuple(M0, M1)),
|
||||||
make_pass_through_transform(K1)),
|
make_pass_through_transform(K1)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
return a_k0_m0_m1_k1_grid_desc;
|
return a_k0_m0_m1_k1_grid_desc;
|
||||||
}
|
}
|
||||||
@@ -251,13 +250,13 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
const auto N1 = Number<NPerBlockN1>{};
|
const auto N1 = Number<NPerBlockN1>{};
|
||||||
const auto N0 = N / N1;
|
const auto N0 = N / N1;
|
||||||
|
|
||||||
const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto b_k0_n0_n1_k1_grid_desc =
|
||||||
b_k0_n_k1_grid_desc,
|
transform_tensor_descriptor(b_k0_n_k1_grid_desc,
|
||||||
make_tuple(make_pass_through_transform(K0),
|
make_tuple(make_pass_through_transform(K0),
|
||||||
make_unmerge_transform(make_tuple(N0, N1)),
|
make_unmerge_transform(make_tuple(N0, N1)),
|
||||||
make_pass_through_transform(K1)),
|
make_pass_through_transform(K1)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
return b_k0_n0_n1_k1_grid_desc;
|
return b_k0_n0_n1_k1_grid_desc;
|
||||||
}
|
}
|
||||||
@@ -284,7 +283,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
constexpr auto M10 = M1 / M11;
|
constexpr auto M10 = M1 / M11;
|
||||||
constexpr auto N10 = N1 / N11;
|
constexpr auto N10 = N1 / N11;
|
||||||
|
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||||
c_m_n_grid_desc,
|
c_m_n_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||||
@@ -355,23 +354,23 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
// A matrix in LDS memory, for blockwise GEMM
|
// A matrix in LDS memory, for blockwise GEMM
|
||||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// TODO: check alignment
|
// TODO: check alignment
|
||||||
// B matrix in LDS memory, for blockwise GEMM
|
// B matrix in LDS memory, for blockwise GEMM
|
||||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||||
|
|
||||||
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
|
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
|
||||||
@@ -381,7 +380,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
"wrong!");
|
"wrong!");
|
||||||
|
|
||||||
// A matrix blockwise copy
|
// A matrix blockwise copy
|
||||||
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
|
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
|
||||||
@@ -405,7 +404,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
make_multi_index(0, 0, 0, 0));
|
make_multi_index(0, 0, 0, 0));
|
||||||
|
|
||||||
// B matrix blockwise copy
|
// B matrix blockwise copy
|
||||||
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
|
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
|
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
|
||||||
@@ -453,9 +452,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||||
|
|
||||||
constexpr auto c_m10_m11_n10_n11_thread_desc =
|
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||||
@@ -471,9 +469,9 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
@@ -496,8 +494,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
|
|
||||||
// LDS double buffer: preload data into LDS
|
// LDS double buffer: preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||||
@@ -516,18 +514,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
// even iteration
|
// even iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
|
||||||
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||||
@@ -542,18 +538,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
// odd iteration
|
// odd iteration
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
AGridMoveSliceWindowStepHacks{});
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
BGridMoveSliceWindowStepHacks{});
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS doubel buffer: load next data from device mem
|
// LDS doubel buffer: load next data from device mem
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(
|
|
||||||
b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -570,18 +564,16 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
// LDS double buffer: tail
|
// LDS double buffer: tail
|
||||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(
|
||||||
a_block_slice_copy_step,
|
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
|
||||||
AGridMoveSliceWindowIteratorHacks{});
|
b_blockwise_copy.MoveSrcSliceWindow(
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
|
||||||
b_block_slice_copy_step,
|
|
||||||
BGridMoveSliceWindowIteratorHacks{});
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// LDS double buffer: load last data from device mem
|
// LDS double buffer: load last data from device mem
|
||||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
|
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});
|
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(
|
blockwise_gemm.Run(
|
||||||
@@ -608,21 +600,8 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
|
|
||||||
// output: register to global memory
|
// output: register to global memory
|
||||||
{
|
{
|
||||||
constexpr auto M11 =
|
|
||||||
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
|
|
||||||
M1PerThreadM111>{};
|
|
||||||
constexpr auto N11 =
|
|
||||||
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
|
|
||||||
N1PerThreadN111>{};
|
|
||||||
|
|
||||||
constexpr index_t M10 = MPerBlockM1 / M11;
|
|
||||||
constexpr index_t N10 = NPerBlockN1 / N11;
|
|
||||||
|
|
||||||
constexpr index_t M111 = M1PerThreadM111;
|
|
||||||
constexpr index_t N111 = N1PerThreadN111;
|
|
||||||
|
|
||||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(
|
make_naive_tensor_descriptor_packed(
|
||||||
make_tuple(I1,
|
make_tuple(I1,
|
||||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
||||||
@@ -634,7 +613,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||||
get_thread_local_1d_id());
|
get_thread_local_1d_id());
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
ThreadwiseTensorSliceTransfer_v1r3<
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
FloatC,
|
FloatC,
|
||||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||||
@@ -662,7 +641,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
CGridIteratorHacks{});
|
CGridStepHacks{});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
|
#ifndef CK_GRIDWISE_GEMM_V2_HPP
|
||||||
#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP
|
#define CK_GRIDWISE_GEMM_V2_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform_helper.hpp"
|
#include "multi_index_transform_helper.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
#include "blockwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
#include "blockwise_gemm_dlops_v3.hpp"
|
#include "blockwise_gemm_dlops_v3.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
@@ -42,12 +42,12 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGlobalIteratorHacks,
|
typename AGlobalStepHacks,
|
||||||
typename BGlobalIteratorHacks,
|
typename BGlobalStepHacks,
|
||||||
typename CGlobalIteratorHacks,
|
typename CGlobalStepHacks,
|
||||||
typename AGlobalMoveSliceWindowIteratorHacks,
|
typename AGlobalMoveSliceWindowStepHacks,
|
||||||
typename BGlobalMoveSliceWindowIteratorHacks>
|
typename BGlobalMoveSliceWindowStepHacks>
|
||||||
struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
struct GridwiseGemmDlops_km_kn_mn_v3
|
||||||
{
|
{
|
||||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||||
{
|
{
|
||||||
@@ -58,7 +58,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
@@ -102,7 +102,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
|
|
||||||
// divide block work by [M, N]
|
// divide block work by [M, N]
|
||||||
#if 0
|
#if 0
|
||||||
const auto k_block_work_num = K / Number<KPerBlock>{};
|
|
||||||
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
|
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
|
||||||
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
|
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
|
||||||
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||||
@@ -114,7 +113,6 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||||
#else
|
#else
|
||||||
// Hack: this force result into SGPR
|
// Hack: this force result into SGPR
|
||||||
const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock);
|
|
||||||
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
|
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
|
||||||
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
|
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
|
||||||
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||||
@@ -134,23 +132,21 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
|
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
|
||||||
|
|
||||||
constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_e_n_ho_wo_block_desc =
|
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
|
||||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
|
|
||||||
|
|
||||||
// c_thread_mtx definition: this is a mess
|
// c_thread_mtx definition: this is a mess
|
||||||
// TODO:: more elegent way of defining c_thread_mtx
|
// TODO:: more elegent way of defining c_thread_mtx
|
||||||
constexpr auto c_k_n_ho_wo_thread_desc =
|
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||||
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
|
||||||
|
|
||||||
auto blockwise_gemm =
|
auto blockwise_gemm =
|
||||||
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
|
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
|
||||||
@@ -184,47 +180,46 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
|
|
||||||
// A matrix blockwise copy
|
// A matrix blockwise copy
|
||||||
auto a_blockwise_copy =
|
auto a_blockwise_copy =
|
||||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<E, KPerBlock>,
|
Sequence<E, KPerBlock>,
|
||||||
ABlockTransferThreadSliceLengths_E_K,
|
ABlockTransferThreadSliceLengths_E_K,
|
||||||
ABlockTransferThreadClusterLengths_E_K,
|
ABlockTransferThreadClusterLengths_E_K,
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
decltype(a_e_k_global_desc),
|
decltype(a_e_k_global_desc),
|
||||||
decltype(a_e_k_desc),
|
decltype(a_e_k_desc),
|
||||||
ABlockTransferSrcAccessOrder,
|
ABlockTransferSrcAccessOrder,
|
||||||
Sequence<0, 1>,
|
Sequence<0, 1>,
|
||||||
ABlockTransferSrcVectorDim,
|
ABlockTransferSrcVectorDim,
|
||||||
1,
|
1,
|
||||||
ABlockTransferSrcScalarPerVector,
|
ABlockTransferSrcScalarPerVector,
|
||||||
ABlockTransferDstScalarPerVector_K,
|
ABlockTransferDstScalarPerVector_K,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
true>(
|
true>(a_e_k_global_desc,
|
||||||
a_e_k_global_desc,
|
make_multi_index(0, k_block_data_on_global),
|
||||||
make_multi_index(0, k_block_data_on_global),
|
a_e_k_desc,
|
||||||
a_e_k_desc,
|
make_multi_index(0, 0));
|
||||||
make_multi_index(0, 0));
|
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_thread_desc =
|
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
|
||||||
|
|
||||||
auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2<
|
auto b_threadwise_transfer =
|
||||||
FloatAB,
|
ThreadwiseTensorSliceTransfer_v2<FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
decltype(b_e_n_ho_wo_global_desc),
|
decltype(b_e_n_ho_wo_global_desc),
|
||||||
decltype(b_e_n_ho_wo_thread_desc),
|
decltype(b_e_n_ho_wo_thread_desc),
|
||||||
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
|
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
|
||||||
BBlockTransferSrcAccessOrder,
|
BBlockTransferSrcAccessOrder,
|
||||||
BBlockTransferSrcVectorDim,
|
BBlockTransferSrcVectorDim,
|
||||||
BBlockTransferSrcScalarPerVector,
|
BBlockTransferSrcScalarPerVector,
|
||||||
1,
|
1,
|
||||||
true>(b_e_n_ho_wo_global_desc,
|
true>(
|
||||||
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
|
b_e_n_ho_wo_global_desc,
|
||||||
|
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
|
||||||
|
|
||||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||||
p_shared_block, a_e_k_desc.GetElementSpaceSize());
|
p_shared_block, a_e_k_desc.GetElementSpaceSize());
|
||||||
@@ -232,44 +227,45 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
// register allocation for output
|
// register allocation for output
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
|
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||||
|
true>
|
||||||
c_thread_buf;
|
c_thread_buf;
|
||||||
|
|
||||||
// initialize output thread tensor
|
// initialize output thread tensor
|
||||||
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
|
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||||
decltype(c_k_n_ho_wo_thread_desc),
|
decltype(c_k_n_ho_wo_thread_desc),
|
||||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
|
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
|
||||||
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||||
|
|
||||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||||
|
|
||||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||||
constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{};
|
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
|
||||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{};
|
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
|
||||||
|
|
||||||
// hack to control index calculation when move slice window for A and B matrix for
|
// hack to control index calculation when move slice window for A and B matrix for
|
||||||
// threadwise copy
|
// threadwise copy
|
||||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack =
|
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
|
||||||
AGlobalMoveSliceWindowIteratorHacks{};
|
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
BGlobalMoveSliceWindowStepHacks{};
|
||||||
BGlobalMoveSliceWindowIteratorHacks{};
|
|
||||||
|
|
||||||
// double regsiter buffer for b
|
// double regsiter buffer for b
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
|
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||||
|
true>
|
||||||
b_thread_even_buf, b_thread_odd_buf;
|
b_thread_even_buf, b_thread_odd_buf;
|
||||||
|
|
||||||
// LDS double buffer: preload data
|
// LDS double buffer: preload data
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_iterator_hacks);
|
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
|
||||||
|
|
||||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||||
b_global_buf,
|
b_global_buf,
|
||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_even_buf,
|
b_thread_even_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
||||||
}
|
}
|
||||||
@@ -293,7 +289,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_odd_buf,
|
b_thread_odd_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||||
@@ -309,7 +305,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_even_buf,
|
b_thread_even_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on current data
|
// LDS double buffer: GEMM on current data
|
||||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||||
@@ -332,7 +328,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
b_e_n_ho_wo_thread_desc,
|
b_e_n_ho_wo_thread_desc,
|
||||||
make_tuple(I0, I0, I0, I0),
|
make_tuple(I0, I0, I0, I0),
|
||||||
b_thread_odd_buf,
|
b_thread_odd_buf,
|
||||||
b_e_n_ho_wo_global_iterator_hacks);
|
b_e_n_ho_wo_global_step_hacks);
|
||||||
|
|
||||||
// LDS double buffer: GEMM on 2nd-last data
|
// LDS double buffer: GEMM on 2nd-last data
|
||||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||||
@@ -351,23 +347,22 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
// output: register to global memory
|
// output: register to global memory
|
||||||
{
|
{
|
||||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
|
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
|
||||||
|
|
||||||
const index_t k_thread_data_on_global =
|
const index_t k_thread_data_on_global =
|
||||||
k_block_data_on_global + k_thread_id * KPerThread;
|
k_block_data_on_global + k_thread_id * KPerThread;
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||||
FloatAcc,
|
FloatC,
|
||||||
FloatC,
|
decltype(c_k_n_ho_wo_thread_desc),
|
||||||
decltype(c_k_n_ho_wo_thread_desc),
|
decltype(c_k_n_ho_wo_global_desc),
|
||||||
decltype(c_k_n_ho_wo_global_desc),
|
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
|
||||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferDstScalarPerVector,
|
||||||
CThreadTransferDstScalarPerVector,
|
CGlobalMemoryDataOperation,
|
||||||
CGlobalMemoryDataOperation,
|
1,
|
||||||
1,
|
true>(
|
||||||
true>(
|
|
||||||
c_k_n_ho_wo_global_desc,
|
c_k_n_ho_wo_global_desc,
|
||||||
make_multi_index(
|
make_multi_index(
|
||||||
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
|
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
|
||||||
@@ -376,7 +371,7 @@ struct GridwiseDynamicGemmDlops_km_kn_mn_v3
|
|||||||
c_thread_buf,
|
c_thread_buf,
|
||||||
c_k_n_ho_wo_global_desc,
|
c_k_n_ho_wo_global_desc,
|
||||||
c_global_buf,
|
c_global_buf,
|
||||||
c_k_n_ho_wo_global_tensor_iterator_hacks);
|
c_k_n_ho_wo_global_tensor_step_hacks);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,14 +1,14 @@
|
|||||||
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
|
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
|
||||||
#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP
|
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_multi_index_transform_helper.hpp"
|
#include "multi_index_transform_helper.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "blockwise_gemm_xdlops.hpp"
|
#include "blockwise_gemm_xdlops.hpp"
|
||||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
#include "blockwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
#include "threadwise_tensor_slice_transfer.hpp"
|
||||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
#include "threadwise_tensor_slice_set.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -24,13 +24,13 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||||
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
|
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
|
||||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||||
{
|
{
|
||||||
constexpr index_t shared_block_size =
|
constexpr index_t shared_block_size =
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
@@ -58,25 +58,25 @@ __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||||
const void CONSTANT* p_c_block_cluster_adaptor)
|
const void CONSTANT* p_c_block_cluster_adaptor)
|
||||||
{
|
{
|
||||||
constexpr index_t shared_block_size =
|
constexpr index_t shared_block_size =
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
|
|
||||||
const auto a_k0_m_k1_grid_desc =
|
const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
|
||||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
|
||||||
const auto b_k0_n_k1_grid_desc =
|
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
|
||||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
|
||||||
const auto c_m0_m1_m2_n_grid_desc =
|
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
|
||||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
|
||||||
const auto c_block_cluster_adaptor =
|
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
|
||||||
*reinterpret_cast<const CBlockClusterAdaptor*>((const void*)p_c_block_cluster_adaptor);
|
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
|
||||||
|
|
||||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||||
|
|
||||||
@@ -126,13 +126,13 @@ template <index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
index_t CThreadTransferSrcDstVectorDim,
|
index_t CThreadTransferSrcDstVectorDim,
|
||||||
index_t CThreadTransferDstScalarPerVector,
|
index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks,
|
typename BGridMoveSliceWindowStepHacks,
|
||||||
bool CAccessOrderMRepeatNRepeat>
|
bool CAccessOrderMRepeatNRepeat>
|
||||||
struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
static constexpr auto I1 = Number<1>{};
|
static constexpr auto I1 = Number<1>{};
|
||||||
@@ -148,12 +148,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
@@ -203,9 +203,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
__host__ __device__ static constexpr auto
|
__host__ __device__ static constexpr auto
|
||||||
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||||
{
|
{
|
||||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
|
||||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
|
||||||
|
|
||||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||||
|
|
||||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||||
@@ -217,10 +214,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||||
|
|
||||||
constexpr auto N0 = Number<CLayout.N1()>{};
|
|
||||||
constexpr auto N1 = Number<CLayout.N0()>{};
|
constexpr auto N1 = Number<CLayout.N0()>{};
|
||||||
|
|
||||||
const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
|
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
|
||||||
c_m_n_grid_desc,
|
c_m_n_grid_desc,
|
||||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
|
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
|
||||||
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
|
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
|
||||||
@@ -269,11 +265,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
|
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
|
||||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||||
{
|
{
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||||
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
|
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||||
@@ -282,8 +273,6 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
|
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
|
||||||
|
|
||||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
|
||||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
|
||||||
|
|
||||||
// divide block work by [M, N]
|
// divide block work by [M, N]
|
||||||
const auto block_work_idx =
|
const auto block_work_idx =
|
||||||
@@ -301,67 +290,65 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
|
|
||||||
// A matrix in LDS memory, dst of blockwise copy
|
// A matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// B matrix in LDS memory, dst of blockwise copy
|
// B matrix in LDS memory, dst of blockwise copy
|
||||||
// be careful of LDS alignment
|
// be careful of LDS alignment
|
||||||
constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
|
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned_v2(
|
||||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||||
|
|
||||||
// A matrix blockwise copy
|
// A matrix blockwise copy
|
||||||
auto a_blockwise_copy =
|
auto a_blockwise_copy =
|
||||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<KPerBlock, MPerBlock, K1>,
|
Sequence<KPerBlock, MPerBlock, K1>,
|
||||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
decltype(a_k0_m_k1_grid_desc),
|
decltype(a_k0_m_k1_grid_desc),
|
||||||
decltype(a_k0_m_k1_block_desc),
|
decltype(a_k0_m_k1_block_desc),
|
||||||
ABlockTransferSrcAccessOrder,
|
ABlockTransferSrcAccessOrder,
|
||||||
Sequence<1, 0, 2>,
|
Sequence<1, 0, 2>,
|
||||||
ABlockTransferSrcVectorDim,
|
ABlockTransferSrcVectorDim,
|
||||||
2,
|
2,
|
||||||
ABlockTransferSrcScalarPerVector,
|
ABlockTransferSrcScalarPerVector,
|
||||||
ABlockTransferDstScalarPerVector_K1,
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
true>(
|
true>(a_k0_m_k1_grid_desc,
|
||||||
a_k0_m_k1_grid_desc,
|
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
a_k0_m_k1_block_desc,
|
||||||
a_k0_m_k1_block_desc,
|
make_multi_index(0, 0, 0));
|
||||||
make_multi_index(0, 0, 0));
|
|
||||||
|
|
||||||
// B matrix blockwise copy
|
// B matrix blockwise copy
|
||||||
auto b_blockwise_copy =
|
auto b_blockwise_copy =
|
||||||
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
|
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||||
InMemoryDataOperationEnum_t::Set,
|
InMemoryDataOperationEnum_t::Set,
|
||||||
Sequence<KPerBlock, NPerBlock, K1>,
|
Sequence<KPerBlock, NPerBlock, K1>,
|
||||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
decltype(b_k0_n_k1_grid_desc),
|
decltype(b_k0_n_k1_grid_desc),
|
||||||
decltype(b_k0_n_k1_block_desc),
|
decltype(b_k0_n_k1_block_desc),
|
||||||
BBlockTransferSrcAccessOrder,
|
BBlockTransferSrcAccessOrder,
|
||||||
Sequence<1, 0, 2>,
|
Sequence<1, 0, 2>,
|
||||||
BBlockTransferSrcVectorDim,
|
BBlockTransferSrcVectorDim,
|
||||||
2,
|
2,
|
||||||
BBlockTransferSrcScalarPerVector,
|
BBlockTransferSrcScalarPerVector,
|
||||||
BBlockTransferDstScalarPerVector_K1,
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
1,
|
1,
|
||||||
1,
|
1,
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
true>(
|
true>(b_k0_n_k1_grid_desc,
|
||||||
b_k0_n_k1_grid_desc,
|
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
b_k0_n_k1_block_desc,
|
||||||
b_k0_n_k1_block_desc,
|
make_multi_index(0, 0, 0));
|
||||||
make_multi_index(0, 0, 0));
|
|
||||||
|
|
||||||
// GEMM definition
|
// GEMM definition
|
||||||
// c_mtx += transpose(a_mtx) * b_mtx
|
// c_mtx += transpose(a_mtx) * b_mtx
|
||||||
@@ -375,7 +362,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
NPerBlock % (NPerWave * NRepeat) == 0,
|
NPerBlock % (NPerWave * NRepeat) == 0,
|
||||||
"wrong!");
|
"wrong!");
|
||||||
|
|
||||||
constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
|
||||||
a_k0_m_k1_block_desc,
|
a_k0_m_k1_block_desc,
|
||||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||||
make_unmerge_transform(
|
make_unmerge_transform(
|
||||||
@@ -384,7 +371,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
|
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
|
||||||
b_k0_n_k1_block_desc,
|
b_k0_n_k1_block_desc,
|
||||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||||
make_unmerge_transform(
|
make_unmerge_transform(
|
||||||
@@ -410,21 +397,19 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
|
|
||||||
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
|
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
|
||||||
|
|
||||||
constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
|
constexpr auto c_mr_nr_blk_desc =
|
||||||
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||||
vector_type<FloatAcc, BlkSize>,
|
vector_type<FloatAcc, BlkSize>,
|
||||||
c_mr_nr_blk_desc.GetElementSpaceSize()>
|
c_mr_nr_blk_desc.GetElementSpaceSize(),
|
||||||
|
true>
|
||||||
c_thread_buf;
|
c_thread_buf;
|
||||||
|
|
||||||
// LDS allocation for A and B: be careful of alignment
|
// LDS allocation for A and B: be careful of alignment
|
||||||
constexpr auto a_block_space_size =
|
constexpr auto a_block_space_size =
|
||||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||||
|
|
||||||
constexpr auto b_block_space_size =
|
|
||||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
|
||||||
|
|
||||||
FloatAB* p_a_block = p_shared_block;
|
FloatAB* p_a_block = p_shared_block;
|
||||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||||
|
|
||||||
@@ -432,15 +417,13 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||||
|
|
||||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||||
constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{};
|
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||||
constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{};
|
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
|
||||||
|
|
||||||
// hack to control index calculation when move slice window for A and B matrix for
|
// hack to control index calculation when move slice window for A and B matrix for
|
||||||
// threadwise copy
|
// threadwise copy
|
||||||
constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack =
|
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
|
||||||
AGridMoveSliceWindowIteratorHacks{};
|
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
|
||||||
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
|
|
||||||
BGridMoveSliceWindowIteratorHacks{};
|
|
||||||
|
|
||||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||||
@@ -449,10 +432,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
|
|
||||||
// preload data into LDS
|
// preload data into LDS
|
||||||
{
|
{
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||||
b_blockwise_copy.RunRead(
|
|
||||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
|
||||||
|
|
||||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||||
@@ -465,18 +446,16 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
{
|
{
|
||||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||||
a_block_slice_copy_step,
|
a_block_slice_copy_step,
|
||||||
a_k0_m_k1_grid_move_slice_window_iterator_hack);
|
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||||
b_block_slice_copy_step,
|
b_block_slice_copy_step,
|
||||||
b_k0_n_k1_grid_move_slice_window_iterator_hack);
|
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||||
|
|
||||||
a_blockwise_copy.RunRead(
|
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||||
a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks);
|
|
||||||
|
|
||||||
block_sync_lds();
|
block_sync_lds();
|
||||||
|
|
||||||
b_blockwise_copy.RunRead(
|
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||||
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks);
|
|
||||||
|
|
||||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||||
|
|
||||||
@@ -506,7 +485,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
constexpr index_t N1 = CLayout.N0();
|
constexpr index_t N1 = CLayout.N0();
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
constexpr auto c_m0_m1_m2_n_thread_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<MRepeat>{},
|
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||||
Number<NRepeat>{},
|
Number<NRepeat>{},
|
||||||
Number<1>{},
|
Number<1>{},
|
||||||
Number<1>{},
|
Number<1>{},
|
||||||
@@ -515,7 +494,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
Number<M2>{},
|
Number<M2>{},
|
||||||
Number<1>{}));
|
Number<1>{}));
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
|
||||||
c_blk_buf_;
|
c_blk_buf_;
|
||||||
|
|
||||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||||
@@ -542,12 +521,12 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
const index_t n_thread_data_on_grid =
|
const index_t n_thread_data_on_grid =
|
||||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
|
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||||
|
|
||||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||||
|
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<
|
ThreadwiseTensorSliceTransfer_v1r3<
|
||||||
FloatC,
|
FloatC,
|
||||||
FloatC,
|
FloatC,
|
||||||
decltype(c_m0_m1_m2_n_thread_desc),
|
decltype(c_m0_m1_m2_n_thread_desc),
|
||||||
@@ -573,7 +552,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_blk_buf_,
|
c_blk_buf_,
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
{
|
{
|
||||||
@@ -581,11 +560,8 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
constexpr index_t M1 = CLayout.N1();
|
constexpr index_t M1 = CLayout.N1();
|
||||||
constexpr index_t M2 = CLayout.M0();
|
constexpr index_t M2 = CLayout.M0();
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
|
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
||||||
I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, BlkSize> c_blk_buf_;
|
|
||||||
|
|
||||||
// calculate origin of thread output tensor on global memory
|
// calculate origin of thread output tensor on global memory
|
||||||
// blockwise GEMM c matrix starting index
|
// blockwise GEMM c matrix starting index
|
||||||
@@ -598,20 +574,20 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
const index_t n_thread_data_on_grid =
|
const index_t n_thread_data_on_grid =
|
||||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
|
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||||
|
|
||||||
auto c_thread_copy =
|
auto c_thread_copy =
|
||||||
ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatC,
|
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
||||||
FloatC,
|
FloatC,
|
||||||
decltype(c_m0_m1_m2_n_thread_desc),
|
decltype(c_m0_m1_m2_n_thread_desc),
|
||||||
decltype(c_m0_m1_m2_n_grid_desc),
|
decltype(c_m0_m1_m2_n_grid_desc),
|
||||||
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
|
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
|
||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
CGlobalMemoryDataOperation,
|
CGlobalMemoryDataOperation,
|
||||||
1,
|
1,
|
||||||
true>{
|
true>{
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
make_multi_index(0,
|
make_multi_index(0,
|
||||||
0,
|
0,
|
||||||
@@ -629,7 +605,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
|
|
||||||
return c_thread_idx_;
|
return c_thread_idx_;
|
||||||
};
|
};
|
||||||
@@ -644,7 +620,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||||
@@ -657,7 +633,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||||
@@ -670,7 +646,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||||
@@ -683,7 +659,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
|||||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||||
c_m0_m1_m2_n_grid_desc,
|
c_m0_m1_m2_n_grid_desc,
|
||||||
c_grid_buf,
|
c_grid_buf,
|
||||||
c_m0_m1_m2_n_grid_tensor_iterator_hacks);
|
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
||||||
@@ -21,10 +21,10 @@ template <typename FloatA,
|
|||||||
typename TKLengths,
|
typename TKLengths,
|
||||||
typename TMLengths,
|
typename TMLengths,
|
||||||
typename TNLengths,
|
typename TNLengths,
|
||||||
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
|
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
|
||||||
{
|
{
|
||||||
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
|
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
|
||||||
@@ -97,10 +97,9 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
|
|||||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||||
|
|
||||||
amd_inner_product_dlop<FloatA, FloatB, FloatC>(
|
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
|
||||||
a_buf[Number<a_offset>{}],
|
b_buf[Number<b_offset>{}],
|
||||||
b_buf[Number<b_offset>{}],
|
c_buf(Number<c_offset>{}));
|
||||||
c_buf(Number<c_offset>{}));
|
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -124,10 +123,10 @@ template <typename FloatA,
|
|||||||
typename TKLengths,
|
typename TKLengths,
|
||||||
typename TMLengths,
|
typename TMLengths,
|
||||||
typename TNLengths,
|
typename TNLengths,
|
||||||
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
|
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
|
||||||
{
|
{
|
||||||
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
|
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
|
||||||
@@ -214,7 +213,7 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_
|
|||||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||||
|
|
||||||
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
|
inner_product<a_vector_t, b_vector_t, FloatC>(
|
||||||
a_vec.template AsType<a_vector_t>()[I0],
|
a_vec.template AsType<a_vector_t>()[I0],
|
||||||
b_vec.template AsType<b_vector_t>()[I0],
|
b_vec.template AsType<b_vector_t>()[I0],
|
||||||
c_buf(Number<c_offset>{}));
|
c_buf(Number<c_offset>{}));
|
||||||
|
|||||||
@@ -19,9 +19,9 @@ template <typename FloatA,
|
|||||||
typename CDesc,
|
typename CDesc,
|
||||||
index_t H,
|
index_t H,
|
||||||
index_t W,
|
index_t W,
|
||||||
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||||
CDesc::IsKnownAtCompileTime(),
|
CDesc::IsKnownAtCompileTime(),
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
struct ThreadwiseGemmDlops_km_kn_mn_v3
|
struct ThreadwiseGemmDlops_km_kn_mn_v3
|
||||||
{
|
{
|
||||||
template <typename ABuffer,
|
template <typename ABuffer,
|
||||||
@@ -57,8 +57,6 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
|
|||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
constexpr auto I0 = Number<0>{};
|
||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
constexpr auto E = ADesc{}.GetLength(I0);
|
constexpr auto E = ADesc{}.GetLength(I0);
|
||||||
constexpr auto K = ADesc{}.GetLength(I1);
|
constexpr auto K = ADesc{}.GetLength(I1);
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
|
#ifndef CK_THREADWISE_TENSOR_SET_HPP
|
||||||
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
|
#define CK_THREADWISE_TENSOR_SET_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -11,12 +11,12 @@ namespace ck {
|
|||||||
// 1. Desc is known at compile-time
|
// 1. Desc is known at compile-time
|
||||||
// 2. Buffer is StaticBuffer
|
// 2. Buffer is StaticBuffer
|
||||||
// 3. OriginIdx is known at compile-time
|
// 3. OriginIdx is known at compile-time
|
||||||
// 4. use #-iterator
|
// 4. use #-step
|
||||||
template <typename Data,
|
template <typename Data,
|
||||||
typename Desc,
|
typename Desc,
|
||||||
typename SliceLengths,
|
typename SliceLengths,
|
||||||
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
|
typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
|
||||||
struct ThreadwiseDynamicTensorSliceSet_v1
|
struct ThreadwiseTensorSliceSet_v1
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = SliceLengths::Size();
|
static constexpr index_t nDim = SliceLengths::Size();
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ struct ThreadwiseDynamicTensorSliceSet_v1
|
|||||||
constexpr auto origin_idx = to_multi_index(OriginIdx{});
|
constexpr auto origin_idx = to_multi_index(OriginIdx{});
|
||||||
|
|
||||||
static_ford<SliceLengths>{}([&](auto access_idx) {
|
static_ford<SliceLengths>{}([&](auto access_idx) {
|
||||||
constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);
|
constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
|
||||||
|
|
||||||
constexpr bool is_valid =
|
constexpr bool is_valid =
|
||||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
|
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
|
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||||
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP
|
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -57,20 +57,20 @@ template <typename SrcData,
|
|||||||
InMemoryDataOperationEnum_t DstInMemOp,
|
InMemoryDataOperationEnum_t DstInMemOp,
|
||||||
index_t DstScalarStrideInVector,
|
index_t DstScalarStrideInVector,
|
||||||
bool DstResetCoordinateAfterRun,
|
bool DstResetCoordinateAfterRun,
|
||||||
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||||
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
struct ThreadwiseTensorSliceTransfer_v1r3
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = SliceLengths::Size();
|
static constexpr index_t nDim = SliceLengths::Size();
|
||||||
|
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
|
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||||
|
|
||||||
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3(
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
|
||||||
const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
const Index& dst_slice_origin_idx)
|
||||||
: dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
|
: dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx))
|
||||||
{
|
{
|
||||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! SrcDesc need to known at compile-time");
|
"wrong! SrcDesc need to known at compile-time");
|
||||||
@@ -78,19 +78,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
|||||||
|
|
||||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||||
{
|
{
|
||||||
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcSliceOriginIdx,
|
template <typename SrcSliceOriginIdx,
|
||||||
typename SrcBuffer,
|
typename SrcBuffer,
|
||||||
typename DstBuffer,
|
typename DstBuffer,
|
||||||
typename DstIteratorHacks>
|
typename DstStepHacks>
|
||||||
__device__ void Run(const SrcDesc&,
|
__device__ void Run(const SrcDesc&,
|
||||||
const SrcSliceOriginIdx&,
|
const SrcSliceOriginIdx&,
|
||||||
const SrcBuffer& src_buf,
|
const SrcBuffer& src_buf,
|
||||||
const DstDesc& dst_desc,
|
const DstDesc& dst_desc,
|
||||||
DstBuffer& dst_buf,
|
DstBuffer& dst_buf,
|
||||||
const DstIteratorHacks& dst_iterator_hacks)
|
const DstStepHacks& dst_step_hacks)
|
||||||
{
|
{
|
||||||
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
static_assert(SrcDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! SrcDesc need to known at compile-time");
|
"wrong! SrcDesc need to known at compile-time");
|
||||||
@@ -127,31 +127,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
|||||||
constexpr auto ordered_access_lengths =
|
constexpr auto ordered_access_lengths =
|
||||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto dst_forward_iterators = generate_tuple(
|
const auto dst_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto dst_backward_iterators = generate_tuple(
|
const auto dst_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -235,13 +235,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
|||||||
{
|
{
|
||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_forward_iterators[dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_backward_iterators[dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -250,10 +250,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
|||||||
// move dst coordinate back to slice origin (or not)
|
// move dst coordinate back to slice origin (or not)
|
||||||
if constexpr(DstResetCoordinateAfterRun)
|
if constexpr(DstResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto dst_reset_iterator =
|
const auto dst_reset_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,11 +268,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||||
|
|
||||||
constexpr auto dst_iterator_hacks =
|
constexpr auto dst_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_iterator_hacks);
|
Run(SrcDesc{}, SrcSliceOriginIdx{}, src_buf, dst_desc, dst_buf, dst_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||||
@@ -345,10 +345,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
|
|||||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step =
|
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||||
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -374,20 +373,20 @@ template <typename SrcData,
|
|||||||
index_t SrcScalarPerVector,
|
index_t SrcScalarPerVector,
|
||||||
index_t SrcScalarStrideInVector,
|
index_t SrcScalarStrideInVector,
|
||||||
bool SrcResetCoordinateAfterRun,
|
bool SrcResetCoordinateAfterRun,
|
||||||
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||||
struct ThreadwiseDynamicTensorSliceTransfer_v2
|
struct ThreadwiseTensorSliceTransfer_v2
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = SliceLengths::Size();
|
static constexpr index_t nDim = SliceLengths::Size();
|
||||||
|
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin_idx)
|
const Index& src_slice_origin_idx)
|
||||||
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
|
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
|
||||||
{
|
{
|
||||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! SrcDesc need to known at compile-time");
|
"wrong! SrcDesc need to known at compile-time");
|
||||||
@@ -395,19 +394,19 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
|||||||
|
|
||||||
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||||
{
|
{
|
||||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer,
|
template <typename SrcBuffer,
|
||||||
typename DstBuffer,
|
typename DstBuffer,
|
||||||
typename DstSliceOriginIdx,
|
typename DstSliceOriginIdx,
|
||||||
typename SrcIteratorHacks>
|
typename SrcStepHacks>
|
||||||
__device__ void Run(const SrcDesc& src_desc,
|
__device__ void Run(const SrcDesc& src_desc,
|
||||||
const SrcBuffer& src_buf,
|
const SrcBuffer& src_buf,
|
||||||
const DstDesc&,
|
const DstDesc&,
|
||||||
const DstSliceOriginIdx&,
|
const DstSliceOriginIdx&,
|
||||||
DstBuffer& dst_buf,
|
DstBuffer& dst_buf,
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
const SrcStepHacks& src_step_hacks)
|
||||||
{
|
{
|
||||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! DstDesc need to known at compile-time");
|
"wrong! DstDesc need to known at compile-time");
|
||||||
@@ -442,31 +441,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
|||||||
constexpr auto ordered_access_lengths =
|
constexpr auto ordered_access_lengths =
|
||||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto src_forward_iterators = generate_tuple(
|
const auto src_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto src_backward_iterators = generate_tuple(
|
const auto src_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -548,13 +547,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
|||||||
{
|
{
|
||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_forward_iterators[dim_access_order[i]]);
|
src_desc, src_coord_, src_forward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_backward_iterators[dim_access_order[i]]);
|
src_desc, src_coord_, src_backward_steps[dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -563,10 +562,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
|||||||
// move src coordinate back to slice origin (or not)
|
// move src coordinate back to slice origin (or not)
|
||||||
if constexpr(SrcResetCoordinateAfterRun)
|
if constexpr(SrcResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto src_reset_iterator =
|
const auto src_reset_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -581,11 +580,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||||
|
|
||||||
constexpr auto src_iterator_hacks =
|
constexpr auto src_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_iterator_hacks);
|
Run(src_desc, src_buf, DstDesc{}, DstSliceOriginIdx{}, dst_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||||
@@ -658,10 +657,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step =
|
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -693,23 +691,23 @@ template <typename SliceLengths,
|
|||||||
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
||||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||||
// save addr computation
|
// save addr computation
|
||||||
struct ThreadwiseDynamicTensorSliceTransfer_v3
|
struct ThreadwiseTensorSliceTransfer_v3
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = SliceLengths::Size();
|
static constexpr index_t nDim = SliceLengths::Size();
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
|
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin,
|
const Index& src_slice_origin,
|
||||||
const DstDesc& dst_desc,
|
const DstDesc& dst_desc,
|
||||||
const Index& dst_slice_origin)
|
const Index& dst_slice_origin)
|
||||||
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
|
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||||
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
|
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||||
{
|
{
|
||||||
// TODO: fix this
|
// TODO: fix this
|
||||||
static_assert(is_same<SrcData, DstData>::value,
|
static_assert(is_same<SrcData, DstData>::value,
|
||||||
@@ -718,18 +716,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
|
|
||||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||||
{
|
{
|
||||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||||
{
|
{
|
||||||
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -757,31 +754,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
constexpr auto ordered_src_access_lengths =
|
constexpr auto ordered_src_access_lengths =
|
||||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto src_forward_iterators = generate_tuple(
|
const auto src_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto src_backward_iterators = generate_tuple(
|
const auto src_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -862,13 +859,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
{
|
{
|
||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -877,17 +874,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
// move src coordinate back to slice origin (or not)
|
// move src coordinate back to slice origin (or not)
|
||||||
if constexpr(SrcResetCoordinateAfterRun)
|
if constexpr(SrcResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto src_reset_iterator =
|
const auto src_reset_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer, typename DstIteratorHacks>
|
template <typename DstBuffer, typename DstStepHacks>
|
||||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
__device__ void
|
||||||
DstBuffer& dst_buf,
|
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||||
const DstIteratorHacks& dst_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -915,35 +911,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
constexpr auto ordered_dst_access_lengths =
|
constexpr auto ordered_dst_access_lengths =
|
||||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto dst_forward_iterators = generate_tuple(
|
const auto dst_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||||
|
|
||||||
return forward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto dst_backward_iterators = generate_tuple(
|
const auto dst_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||||
|
|
||||||
return backward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -1026,13 +1018,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
{
|
{
|
||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -1041,10 +1033,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
// move dst coordinate back to slice origin (or not)
|
// move dst coordinate back to slice origin (or not)
|
||||||
if constexpr(DstResetCoordinateAfterRun)
|
if constexpr(DstResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto dst_reset_iterator =
|
const auto dst_reset_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1055,11 +1047,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||||
|
|
||||||
constexpr auto src_iterator_hacks =
|
constexpr auto src_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunRead(src_desc, src_buf, src_iterator_hacks);
|
RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer>
|
template <typename DstBuffer>
|
||||||
@@ -1069,11 +1061,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||||
|
|
||||||
constexpr auto dst_iterator_hacks =
|
constexpr auto dst_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
|
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||||
@@ -1206,18 +1198,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step =
|
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin_step_idx,
|
const Index& src_slice_origin_step_idx,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||||
const auto adjusted_step_idx =
|
const auto adjusted_step_idx =
|
||||||
@@ -1225,10 +1216,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
|
const auto adjusted_step = make_tensor_coordinate_step(
|
||||||
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
|
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||||
@@ -1240,19 +1231,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step =
|
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||||
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr auto buffer_desc_ =
|
static constexpr auto buffer_desc_ =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
|
make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
|
||||||
|
|
||||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||||
|
|
||||||
SrcCoord src_coord_;
|
SrcCoord src_coord_;
|
||||||
DstCoord dst_coord_;
|
DstCoord dst_coord_;
|
||||||
@@ -1264,37 +1254,36 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
|||||||
// 2. SrcBuffer is DynamicBuffer
|
// 2. SrcBuffer is DynamicBuffer
|
||||||
// 3. src_ref_idx is known at run-time
|
// 3. src_ref_idx is known at run-time
|
||||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||||
// 5. use #-iterator
|
// 5. use #-step
|
||||||
// 2. dst:
|
// 2. dst:
|
||||||
// 1. DstDesc is known at compile-time
|
// 1. DstDesc is known at compile-time
|
||||||
// 2. DstBuffer is StaticBuffer
|
// 2. DstBuffer is StaticBuffer
|
||||||
// 3. DstOriginIdx is known at compile-time
|
// 3. DstOriginIdx is known at compile-time
|
||||||
// 4. use direct address calculation
|
// 4. use direct address calculation
|
||||||
// 3. vector access on src
|
// 3. vector access on src
|
||||||
template <
|
template <typename SrcData,
|
||||||
typename SrcData,
|
typename DstData,
|
||||||
typename DstData,
|
typename SrcDesc,
|
||||||
typename SrcDesc,
|
typename DstDesc,
|
||||||
typename DstDesc,
|
typename SliceLengths,
|
||||||
typename SliceLengths,
|
typename DimAccessOrder,
|
||||||
typename DimAccessOrder,
|
index_t SrcVectorDim,
|
||||||
index_t SrcVectorDim,
|
index_t SrcScalarPerVector,
|
||||||
index_t SrcScalarPerVector,
|
index_t SrcScalarStrideInVector,
|
||||||
index_t SrcScalarStrideInVector,
|
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||||
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
bool>::type = false>
|
||||||
bool>::type = false>
|
struct ThreadwiseTensorSliceTransfer_v4
|
||||||
struct ThreadwiseDynamicTensorSliceTransfer_v4
|
|
||||||
{
|
{
|
||||||
static constexpr index_t nDim = SliceLengths::Size();
|
static constexpr index_t nDim = SliceLengths::Size();
|
||||||
|
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
|
||||||
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||||
{
|
{
|
||||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||||
@@ -1390,13 +1379,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
|
|||||||
constexpr auto src_ref_to_data_disp_idx =
|
constexpr auto src_ref_to_data_disp_idx =
|
||||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||||
|
|
||||||
constexpr auto src_ref_to_data_disp_coord_iterator =
|
constexpr auto src_ref_to_data_disp_coord_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
|
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||||
|
|
||||||
auto src_data_coord = src_ref_coord_;
|
auto src_data_coord = src_ref_coord_;
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||||
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
|
|
||||||
|
|
||||||
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
|
||||||
|
|
||||||
@@ -1435,10 +1423,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
|
|||||||
{
|
{
|
||||||
constexpr auto src_desc = SrcDesc{};
|
constexpr auto src_desc = SrcDesc{};
|
||||||
|
|
||||||
const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
|
const auto src_slice_move_step_iter =
|
||||||
src_desc, to_multi_index(src_slice_move_step_idx));
|
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
|
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||||
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
|
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ template <typename SliceLengths,
|
|||||||
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
||||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||||
// save addr computation
|
// save addr computation
|
||||||
struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
struct ThreadwiseTensorSliceTransfer_v3r1
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
static constexpr auto I1 = Number<1>{};
|
static constexpr auto I1 = Number<1>{};
|
||||||
@@ -38,18 +38,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
static constexpr index_t nDim = SliceLengths::Size();
|
static constexpr index_t nDim = SliceLengths::Size();
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
|
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));
|
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin,
|
const Index& src_slice_origin,
|
||||||
const DstDesc& dst_desc,
|
const DstDesc& dst_desc,
|
||||||
const Index& dst_slice_origin)
|
const Index& dst_slice_origin)
|
||||||
: src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
|
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||||
dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
|
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||||
{
|
{
|
||||||
// TODO: fix this
|
// TODO: fix this
|
||||||
static_assert(is_same<SrcData, DstData>::value,
|
static_assert(is_same<SrcData, DstData>::value,
|
||||||
@@ -64,18 +64,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
|
|
||||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||||
{
|
{
|
||||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||||
{
|
{
|
||||||
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
template <typename SrcBuffer, typename SrcStepHacks>
|
||||||
__device__ void RunRead(const SrcDesc& src_desc,
|
__device__ void
|
||||||
const SrcBuffer& src_buf,
|
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||||
const SrcIteratorHacks& src_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -96,9 +95,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
I1),
|
I1),
|
||||||
SrcVectorTensorContiguousDimOrder{});
|
SrcVectorTensorContiguousDimOrder{});
|
||||||
|
|
||||||
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
|
constexpr auto src_vector_desc =
|
||||||
sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
||||||
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
||||||
|
|
||||||
// access order and lengths
|
// access order and lengths
|
||||||
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
||||||
@@ -108,31 +107,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
constexpr auto ordered_src_access_lengths =
|
constexpr auto ordered_src_access_lengths =
|
||||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto src_forward_iterators = generate_tuple(
|
const auto src_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto src_backward_iterators = generate_tuple(
|
const auto src_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
return make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -219,13 +218,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
{
|
{
|
||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -234,17 +233,16 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
// move src coordinate back to slice origin (or not)
|
// move src coordinate back to slice origin (or not)
|
||||||
if constexpr(SrcResetCoordinateAfterRun)
|
if constexpr(SrcResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto src_reset_iterator =
|
const auto src_reset_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer, typename DstIteratorHacks>
|
template <typename DstBuffer, typename DstStepHacks>
|
||||||
__device__ void RunWrite(const DstDesc& dst_desc,
|
__device__ void
|
||||||
DstBuffer& dst_buf,
|
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||||
const DstIteratorHacks& dst_iterator_hacks)
|
|
||||||
{
|
{
|
||||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||||
@@ -265,9 +263,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
I1),
|
I1),
|
||||||
DstVectorTensorContiguousDimOrder{});
|
DstVectorTensorContiguousDimOrder{});
|
||||||
|
|
||||||
constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
|
constexpr auto dst_vector_desc =
|
||||||
sequence_to_tuple_of_number(dst_vector_tensor_lengths),
|
make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(dst_vector_tensor_lengths),
|
||||||
sequence_to_tuple_of_number(dst_vector_tensor_strides));
|
sequence_to_tuple_of_number(dst_vector_tensor_strides));
|
||||||
|
|
||||||
// dst access order and lengths
|
// dst access order and lengths
|
||||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
|
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
|
||||||
@@ -277,35 +275,31 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
constexpr auto ordered_dst_access_lengths =
|
constexpr auto ordered_dst_access_lengths =
|
||||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||||
|
|
||||||
// make forward iterators
|
// make forward steps
|
||||||
const auto dst_forward_iterators = generate_tuple(
|
const auto dst_forward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index forward_step;
|
Index forward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, forward_step, dst_iterator_hacks[I0][i]);
|
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||||
|
|
||||||
return forward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
// make backward iterators
|
// make backward steps
|
||||||
const auto dst_backward_iterators = generate_tuple(
|
const auto dst_backward_steps = generate_tuple(
|
||||||
[&](auto i) {
|
[&](auto i) {
|
||||||
Index backward_step;
|
Index backward_step_idx;
|
||||||
|
|
||||||
static_for<0, nDim, 1>{}([&](auto j) {
|
static_for<0, nDim, 1>{}([&](auto j) {
|
||||||
backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
||||||
});
|
});
|
||||||
|
|
||||||
const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
|
return make_tensor_coordinate_step(
|
||||||
dst_desc, backward_step, dst_iterator_hacks[I1][i]);
|
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||||
|
|
||||||
return backward_iterator;
|
|
||||||
},
|
},
|
||||||
Number<nDim>{});
|
Number<nDim>{});
|
||||||
|
|
||||||
@@ -394,13 +388,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
{
|
{
|
||||||
if constexpr(forward_sweep[i])
|
if constexpr(forward_sweep[i])
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(
|
||||||
dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
|
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -409,10 +403,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
// move dst coordinate back to slice origin (or not)
|
// move dst coordinate back to slice origin (or not)
|
||||||
if constexpr(DstResetCoordinateAfterRun)
|
if constexpr(DstResetCoordinateAfterRun)
|
||||||
{
|
{
|
||||||
const auto dst_reset_iterator =
|
const auto dst_reset_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());
|
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
|
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -423,11 +417,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||||
|
|
||||||
constexpr auto src_iterator_hacks =
|
constexpr auto src_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunRead(src_desc, src_buf, src_iterator_hacks);
|
RunRead(src_desc, src_buf, src_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DstBuffer>
|
template <typename DstBuffer>
|
||||||
@@ -437,11 +431,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
|
|
||||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||||
|
|
||||||
constexpr auto dst_iterator_hacks =
|
constexpr auto dst_step_hacks =
|
||||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||||
|
|
||||||
RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
|
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||||
@@ -564,18 +558,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step =
|
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
|
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||||
template <typename SrcMoveSliceWindowIteratorHack>
|
template <typename SrcMoveSliceWindowStepHack>
|
||||||
__device__ void
|
__device__ void
|
||||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||||
const Index& src_slice_origin_step_idx,
|
const Index& src_slice_origin_step_idx,
|
||||||
const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
|
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||||
{
|
{
|
||||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||||
const auto adjusted_step_idx =
|
const auto adjusted_step_idx =
|
||||||
@@ -583,10 +576,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
|
const auto adjusted_step = make_tensor_coordinate_step(
|
||||||
src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);
|
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||||
@@ -598,19 +591,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||||
|
|
||||||
// is it OK to construct a new step every time?
|
// is it OK to construct a new step every time?
|
||||||
const auto adjusted_step =
|
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||||
make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);
|
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr auto buffer_desc_ =
|
static constexpr auto buffer_desc_ =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
|
make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
|
||||||
|
|
||||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||||
|
|
||||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||||
|
|
||||||
SrcCoord src_coord_;
|
SrcCoord src_coord_;
|
||||||
DstCoord dst_coord_;
|
DstCoord dst_coord_;
|
||||||
@@ -622,25 +614,24 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
|
|||||||
// 2. SrcBuffer is DynamicBuffer
|
// 2. SrcBuffer is DynamicBuffer
|
||||||
// 3. src_ref_idx is known at run-time
|
// 3. src_ref_idx is known at run-time
|
||||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||||
// 5. use #-iterator
|
// 5. use #-step
|
||||||
// 2. dst:
|
// 2. dst:
|
||||||
// 1. DstDesc is known at compile-time
|
// 1. DstDesc is known at compile-time
|
||||||
// 2. DstBuffer is StaticBuffer
|
// 2. DstBuffer is StaticBuffer
|
||||||
// 3. DstOriginIdx is known at compile-time
|
// 3. DstOriginIdx is known at compile-time
|
||||||
// 4. use direct address calculation
|
// 4. use direct address calculation
|
||||||
// 3. vector access on src
|
// 3. vector access on src
|
||||||
template <
|
template <typename SrcData,
|
||||||
typename SrcData,
|
typename DstData,
|
||||||
typename DstData,
|
typename SrcDesc,
|
||||||
typename SrcDesc,
|
typename DstDesc,
|
||||||
typename DstDesc,
|
typename SliceLengths,
|
||||||
typename SliceLengths,
|
typename DimAccessOrder,
|
||||||
typename DimAccessOrder,
|
typename SrcVectorTensorLengths,
|
||||||
typename SrcVectorTensorLengths,
|
typename SrcVectorTensorContiguousDimOrder,
|
||||||
typename SrcVectorTensorContiguousDimOrder,
|
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||||
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
bool>::type = false>
|
||||||
bool>::type = false>
|
struct ThreadwiseTensorSliceTransfer_v4r1
|
||||||
struct ThreadwiseDynamicTensorSliceTransfer_v4r1
|
|
||||||
{
|
{
|
||||||
static constexpr auto I0 = Number<0>{};
|
static constexpr auto I0 = Number<0>{};
|
||||||
static constexpr auto I1 = Number<1>{};
|
static constexpr auto I1 = Number<1>{};
|
||||||
@@ -649,12 +640,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
|
|
||||||
using Index = MultiIndex<nDim>;
|
using Index = MultiIndex<nDim>;
|
||||||
|
|
||||||
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
|
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
|
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||||
|
|
||||||
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx)
|
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
|
||||||
: src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||||
{
|
{
|
||||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||||
@@ -712,9 +703,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
I1),
|
I1),
|
||||||
SrcVectorTensorContiguousDimOrder{});
|
SrcVectorTensorContiguousDimOrder{});
|
||||||
|
|
||||||
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
|
constexpr auto src_vector_desc =
|
||||||
sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
make_naive_tensor_descriptor_v2(sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
||||||
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
||||||
|
|
||||||
// access order and lengths
|
// access order and lengths
|
||||||
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
||||||
@@ -734,13 +725,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
constexpr auto src_ref_to_data_disp_idx =
|
constexpr auto src_ref_to_data_disp_idx =
|
||||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||||
|
|
||||||
constexpr auto src_ref_to_data_disp_coord_iterator =
|
constexpr auto src_ref_to_data_disp_coord_step =
|
||||||
make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);
|
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||||
|
|
||||||
auto src_data_coord = src_ref_coord_;
|
auto src_data_coord = src_ref_coord_;
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(
|
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||||
src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);
|
|
||||||
|
|
||||||
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
||||||
|
|
||||||
@@ -775,10 +765,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4r1
|
|||||||
{
|
{
|
||||||
constexpr auto src_desc = SrcDesc{};
|
constexpr auto src_desc = SrcDesc{};
|
||||||
|
|
||||||
const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
|
const auto src_slice_move_step_iter =
|
||||||
src_desc, to_multi_index(src_slice_move_step_idx));
|
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||||
|
|
||||||
move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@@ -350,8 +350,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
|
|||||||
class FloatC>
|
class FloatC>
|
||||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||||
{
|
{
|
||||||
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
|
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||||
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
|
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||||
|
|
||||||
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
|
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
|
||||||
p_a, p_b, reg_c);
|
p_a, p_b, reg_c);
|
||||||
@@ -384,8 +384,8 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
|
|||||||
class FloatC>
|
class FloatC>
|
||||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||||
{
|
{
|
||||||
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
|
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||||
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
|
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||||
|
|
||||||
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
|
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
|
||||||
}
|
}
|
||||||
@@ -417,8 +417,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
|
|||||||
class FloatC>
|
class FloatC>
|
||||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||||
{
|
{
|
||||||
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
|
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||||
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
|
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||||
|
|
||||||
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
|
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
|
||||||
}
|
}
|
||||||
@@ -450,8 +450,8 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
|
|||||||
class FloatC>
|
class FloatC>
|
||||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||||
{
|
{
|
||||||
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
|
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||||
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
|
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||||
|
|
||||||
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
|
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
|
||||||
}
|
}
|
||||||
@@ -483,8 +483,8 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
|
|||||||
class FloatC>
|
class FloatC>
|
||||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||||
{
|
{
|
||||||
const auto p_a = reinterpret_cast<const ushort2_t*>(a);
|
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||||
const auto p_b = reinterpret_cast<const ushort2_t*>(b);
|
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||||
|
|
||||||
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
|
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
|
||||||
}
|
}
|
||||||
|
|||||||
44
composable_kernel/include/utility/amd_address_space.hpp
Normal file
44
composable_kernel/include/utility/amd_address_space.hpp
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
#ifndef CK_AMD_ADDRESS_SPACE_HPP
|
||||||
|
#define CK_AMD_ADDRESS_SPACE_HPP
|
||||||
|
|
||||||
|
#include "config.hpp"
|
||||||
|
#include "c_style_pointer_cast.hpp"
|
||||||
|
|
||||||
|
// Address Space for AMDGCN
|
||||||
|
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
|
||||||
|
enum AddressSpaceEnum_t
|
||||||
|
{
|
||||||
|
Generic,
|
||||||
|
Global,
|
||||||
|
Lds,
|
||||||
|
Sgpr,
|
||||||
|
Vgpr,
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
|
||||||
|
{
|
||||||
|
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
|
||||||
|
// only c-style pointer cast seems be able to be compiled
|
||||||
|
#pragma clang diagnostic push
|
||||||
|
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||||
|
return (T*)p; // NOLINT(old-style-cast)
|
||||||
|
#pragma clang diagnostic pop
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
|
||||||
|
{
|
||||||
|
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
|
||||||
|
// only c-style pointer cast seems be able to be compiled
|
||||||
|
#pragma clang diagnostic push
|
||||||
|
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||||
|
return (T CONSTANT*)p; // NOLINT(old-style-cast)
|
||||||
|
#pragma clang diagnostic pop
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ck
|
||||||
|
#endif
|
||||||
@@ -1,34 +1,34 @@
|
|||||||
#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP
|
#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
|
||||||
#define CK_AMD_BUFFER_ADDRESSING_V2_HPP
|
#define CK_AMD_BUFFER_ADDRESSING_HPP
|
||||||
|
|
||||||
#include "data_type.hpp"
|
#include "data_type.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
union BufferResource_v2
|
union BufferResource
|
||||||
{
|
{
|
||||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||||
int32x4_t data;
|
int32x4_t content;
|
||||||
StaticallyIndexedArray<T*, 2> address;
|
StaticallyIndexedArray<T*, 2> address;
|
||||||
StaticallyIndexedArray<int32_t, 4> range;
|
StaticallyIndexedArray<int32_t, 4> range;
|
||||||
StaticallyIndexedArray<int32_t, 4> config;
|
StaticallyIndexedArray<int32_t, 4> config;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size)
|
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
|
||||||
{
|
{
|
||||||
BufferResource_v2<T> wave_buffer_resource;
|
BufferResource<T> wave_buffer_resource;
|
||||||
|
|
||||||
// wavewise base address (64 bit)
|
// wavewise base address (64 bit)
|
||||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||||
// wavewise range (32 bit)
|
// wavewise range (32 bit)
|
||||||
wave_buffer_resource.range(Number<2>{}) = data_space_size * sizeof(T);
|
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
|
||||||
// wavewise setting (32 bit)
|
// wavewise setting (32 bit)
|
||||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||||
|
|
||||||
return wave_buffer_resource.data;
|
return wave_buffer_resource.content;
|
||||||
}
|
}
|
||||||
|
|
||||||
// load
|
// load
|
||||||
@@ -204,10 +204,9 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
|
|||||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||||
|
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ typename vector_type<T, N>::type
|
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||||
amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
index_t src_thread_addr_offset,
|
||||||
index_t src_thread_addr_offset,
|
index_t src_wave_addr_offset)
|
||||||
index_t src_wave_addr_offset)
|
|
||||||
{
|
{
|
||||||
static_assert(
|
static_assert(
|
||||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||||
@@ -412,10 +411,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type src_thread_data,
|
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||||
int32x4_t dst_wave_buffer_resource,
|
int32x4_t dst_wave_buffer_resource,
|
||||||
index_t dst_thread_addr_offset,
|
index_t dst_thread_addr_offset,
|
||||||
index_t dst_wave_addr_offset)
|
index_t dst_wave_addr_offset)
|
||||||
{
|
{
|
||||||
static_assert(
|
static_assert(
|
||||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||||
@@ -584,67 +583,95 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
|||||||
|
|
||||||
// buffer_load requires:
|
// buffer_load requires:
|
||||||
// 1) p_src_wave must be in global memory space
|
// 1) p_src_wave must be in global memory space
|
||||||
// 2) p_src_wave to be a wavewise pointer.
|
// 2) p_src_wave must be a wavewise pointer.
|
||||||
// It is user's responsibility to make sure that is true.
|
// It is user's responsibility to make sure that is true.
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ typename vector_type_maker<T, N>::type::type
|
__device__ typename vector_type_maker<T, N>::type::type
|
||||||
amd_buffer_load_v2(const T* p_src_wave,
|
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
|
||||||
index_t src_thread_data_offset,
|
index_t src_thread_element_offset,
|
||||||
bool src_thread_data_valid,
|
bool src_thread_element_valid,
|
||||||
index_t src_element_space)
|
index_t src_element_space_size)
|
||||||
{
|
{
|
||||||
const int32x4_t src_wave_buffer_resource =
|
const int32x4_t src_wave_buffer_resource =
|
||||||
make_wave_buffer_resource(p_src_wave, src_element_space);
|
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||||
|
|
||||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T);
|
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||||
|
|
||||||
|
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||||
|
using scalar_t = typename scalar_type<vector_t>::type;
|
||||||
|
|
||||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
|
||||||
using scalar_t = typename scalar_type<vector_t>::type;
|
|
||||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||||
uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff;
|
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
|
||||||
|
|
||||||
return amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
return amd_buffer_load_impl<scalar_t, vector_size>(
|
||||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||||
#else
|
#else
|
||||||
vector_t tmp = amd_buffer_load_impl_v2<scalar_t, vector_size>(
|
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||||
|
|
||||||
return src_thread_data_valid ? tmp : vector_t(0);
|
return src_thread_element_valid ? tmp : vector_t(0);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buffer_load requires:
|
||||||
|
// 1) p_src_wave must be in global memory space
|
||||||
|
// 2) p_src_wave must be a wavewise pointer.
|
||||||
|
// It is user's responsibility to make sure that is true.
|
||||||
|
template <typename T, index_t N>
|
||||||
|
__device__ typename vector_type_maker<T, N>::type::type
|
||||||
|
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||||
|
index_t src_thread_element_offset,
|
||||||
|
bool src_thread_element_valid,
|
||||||
|
index_t src_element_space_size,
|
||||||
|
T customized_value)
|
||||||
|
{
|
||||||
|
const int32x4_t src_wave_buffer_resource =
|
||||||
|
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||||
|
|
||||||
|
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||||
|
|
||||||
|
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||||
|
using scalar_t = typename scalar_type<vector_t>::type;
|
||||||
|
|
||||||
|
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||||
|
|
||||||
|
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||||
|
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||||
|
|
||||||
|
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||||
|
}
|
||||||
|
|
||||||
// buffer_store requires:
|
// buffer_store requires:
|
||||||
// 1) p_dst_wave must be global memory
|
// 1) p_dst_wave must be global memory
|
||||||
// 2) p_dst_wave to be a wavewise pointer.
|
// 2) p_dst_wave to be a wavewise pointer.
|
||||||
// It is user's responsibility to make sure that is true.
|
// It is user's responsibility to make sure that is true.
|
||||||
template <typename T, index_t N>
|
template <typename T, index_t N>
|
||||||
__device__ void
|
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||||
amd_buffer_store_v2(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
T* p_dst_wave,
|
||||||
T* p_dst_wave,
|
const index_t dst_thread_element_offset,
|
||||||
const index_t dst_thread_data_offset,
|
const bool dst_thread_element_valid,
|
||||||
const bool dst_thread_data_valid,
|
const index_t dst_element_space_size)
|
||||||
const index_t dst_element_space)
|
|
||||||
{
|
{
|
||||||
const int32x4_t dst_wave_buffer_resource =
|
const int32x4_t dst_wave_buffer_resource =
|
||||||
make_wave_buffer_resource(p_dst_wave, dst_element_space);
|
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||||
|
|
||||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T);
|
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||||
|
|
||||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||||
using scalar_t = typename scalar_type<vector_t>::type;
|
using scalar_t = typename scalar_type<vector_t>::type;
|
||||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||||
uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff;
|
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||||
|
|
||||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||||
#else
|
#else
|
||||||
if(dst_thread_data_valid)
|
if(dst_thread_element_valid)
|
||||||
{
|
{
|
||||||
amd_buffer_store_impl_v2<scalar_t, vector_size>(
|
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -1,188 +0,0 @@
|
|||||||
#ifndef CK_AMD_DLOP_HPP
|
|
||||||
#define CK_AMD_DLOP_HPP
|
|
||||||
|
|
||||||
#include "data_type.hpp"
|
|
||||||
|
|
||||||
namespace ck {
|
|
||||||
|
|
||||||
template <typename TA, typename TB, typename TC>
|
|
||||||
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void
|
|
||||||
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
|
|
||||||
{
|
|
||||||
#if CK_USE_AMD_DLOP_INLINE_ASM
|
|
||||||
asm volatile("\n \
|
|
||||||
v_fmac_f32 %0, %1, %2 \n \
|
|
||||||
"
|
|
||||||
: "=v"(c)
|
|
||||||
: "v"(a), "v"(b), "0"(c));
|
|
||||||
#else
|
|
||||||
c += a * b;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void
|
|
||||||
amd_inner_product_dlop<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I0],
|
|
||||||
vector_type<float, 2>{b}.AsType<float>()[I0],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<float, 2>{a}.AsType<float>()[I1],
|
|
||||||
vector_type<float, 2>{b}.AsType<float>()[I1],
|
|
||||||
c);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void
|
|
||||||
amd_inner_product_dlop<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I0],
|
|
||||||
vector_type<float, 4>{b}.AsType<float>()[I0],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I1],
|
|
||||||
vector_type<float, 4>{b}.AsType<float>()[I1],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I2],
|
|
||||||
vector_type<float, 4>{b}.AsType<float>()[I2],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<float, 4>{a}.AsType<float>()[I3],
|
|
||||||
vector_type<float, 4>{b}.AsType<float>()[I3],
|
|
||||||
c);
|
|
||||||
}
|
|
||||||
|
|
||||||
#if CK_USE_AMD_DLOP
|
|
||||||
template <>
|
|
||||||
__device__ void
|
|
||||||
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
|
|
||||||
{
|
|
||||||
#if CK_USE_AMD_DLOP_INLINE_ASM
|
|
||||||
asm volatile("\n \
|
|
||||||
v_dot2_f32_f16 %0, %1, %2, %0\n \
|
|
||||||
"
|
|
||||||
: "=v"(c)
|
|
||||||
: "v"(a), "v"(b), "0"(c));
|
|
||||||
#else
|
|
||||||
c = __builtin_amdgcn_sdot2(a, b, c, false);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void
|
|
||||||
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
|
|
||||||
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
|
|
||||||
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
|
|
||||||
c);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void
|
|
||||||
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
|
|
||||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
|
|
||||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
|
|
||||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
|
|
||||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
|
|
||||||
c);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a,
|
|
||||||
const int8x4_t& b,
|
|
||||||
int32_t& c)
|
|
||||||
{
|
|
||||||
#if CK_USE_AMD_DLOP_INLINE_ASM
|
|
||||||
asm volatile("\n \
|
|
||||||
v_dot4_i32_i8 %0, %1, %2, %0\n \
|
|
||||||
"
|
|
||||||
: "=v"(c)
|
|
||||||
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
|
|
||||||
#else
|
|
||||||
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a,
|
|
||||||
const int8x8_t& b,
|
|
||||||
int32_t& c)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
|
|
||||||
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
|
|
||||||
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
|
|
||||||
c);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
|
|
||||||
const int8x16_t& b,
|
|
||||||
int32_t& c)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
|
|
||||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
|
|
||||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
|
|
||||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
|
|
||||||
c);
|
|
||||||
|
|
||||||
amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
|
|
||||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
|
|
||||||
c);
|
|
||||||
}
|
|
||||||
#endif // CK_USE_AMD_DLOP
|
|
||||||
|
|
||||||
} // namespace ck
|
|
||||||
#endif
|
|
||||||
@@ -2,6 +2,9 @@
|
|||||||
#define CK_AMD_INLINE_ASM_HPP
|
#define CK_AMD_INLINE_ASM_HPP
|
||||||
|
|
||||||
#include "data_type.hpp"
|
#include "data_type.hpp"
|
||||||
|
#include "c_style_pointer_cast.hpp"
|
||||||
|
|
||||||
|
// TODO: deprecate all amd_assembly_outer_product_xxx
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -53,9 +56,9 @@ __device__ void
|
|||||||
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
|
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
|
||||||
{
|
{
|
||||||
// TODO remove pointer casting
|
// TODO remove pointer casting
|
||||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
|
||||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
|
||||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
|
||||||
|
|
||||||
// do dot2 two times
|
// do dot2 two times
|
||||||
asm volatile("\n \
|
asm volatile("\n \
|
||||||
@@ -114,11 +117,11 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
|
|||||||
float& c3)
|
float& c3)
|
||||||
{
|
{
|
||||||
// TODO remove pointer casting
|
// TODO remove pointer casting
|
||||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
|
||||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
|
||||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
|
||||||
const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
|
const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
|
||||||
const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
|
const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
|
||||||
|
|
||||||
// do dot2 two times
|
// do dot2 two times
|
||||||
asm volatile("\n \
|
asm volatile("\n \
|
||||||
@@ -160,11 +163,11 @@ __device__ void amd_assembly_outer_product_1x4(half8_t a,
|
|||||||
{
|
{
|
||||||
|
|
||||||
// TODO remove pointer casting
|
// TODO remove pointer casting
|
||||||
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
|
const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
|
||||||
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
|
const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
|
||||||
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
|
const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
|
||||||
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
|
const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
|
||||||
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
|
const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
|
||||||
|
|
||||||
amd_assembly_outer_product_1x4(
|
amd_assembly_outer_product_1x4(
|
||||||
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
|
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
|
||||||
@@ -184,11 +187,11 @@ __device__ void amd_assembly_outer_product_1x4(half16_t a,
|
|||||||
float& c3)
|
float& c3)
|
||||||
{
|
{
|
||||||
// TODO remove pointer casting
|
// TODO remove pointer casting
|
||||||
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
|
const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
|
||||||
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
|
const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
|
||||||
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
|
const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
|
||||||
const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
|
const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
|
||||||
const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
|
const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
|
||||||
|
|
||||||
amd_assembly_outer_product_1x4(
|
amd_assembly_outer_product_1x4(
|
||||||
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
|
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
|
||||||
|
|||||||
22
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
22
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
#ifndef CK_C_STYLE_POINTER_CAST_HPP
|
||||||
|
#define CK_C_STYLE_POINTER_CAST_HPP
|
||||||
|
|
||||||
|
#include "type.hpp"
|
||||||
|
#include "enable_if.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
|
||||||
|
template <typename PY,
|
||||||
|
typename PX,
|
||||||
|
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
|
||||||
|
__host__ __device__ PY c_style_pointer_cast(PX p_x)
|
||||||
|
{
|
||||||
|
#pragma clang diagnostic push
|
||||||
|
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||||
|
#pragma clang diagnostic ignored "-Wcast-align"
|
||||||
|
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
|
||||||
|
#pragma clang diagnostic pop
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ck
|
||||||
|
#endif
|
||||||
@@ -7,13 +7,14 @@
|
|||||||
#include "statically_indexed_array.hpp"
|
#include "statically_indexed_array.hpp"
|
||||||
#include "container_element_picker.hpp"
|
#include "container_element_picker.hpp"
|
||||||
#include "multi_index.hpp"
|
#include "multi_index.hpp"
|
||||||
#include "data_type_enum.hpp"
|
|
||||||
#include "data_type.hpp"
|
#include "data_type.hpp"
|
||||||
#include "data_type_helper.hpp"
|
#include "data_type_enum.hpp"
|
||||||
|
#include "data_type_enum_helper.hpp"
|
||||||
#include "functional.hpp"
|
#include "functional.hpp"
|
||||||
#include "functional2.hpp"
|
#include "functional2.hpp"
|
||||||
#include "functional3.hpp"
|
#include "functional3.hpp"
|
||||||
#include "functional4.hpp"
|
#include "functional4.hpp"
|
||||||
|
#include "enable_if.hpp"
|
||||||
#include "integral_constant.hpp"
|
#include "integral_constant.hpp"
|
||||||
#include "math.hpp"
|
#include "math.hpp"
|
||||||
#include "number.hpp"
|
#include "number.hpp"
|
||||||
@@ -23,21 +24,21 @@
|
|||||||
#include "tuple.hpp"
|
#include "tuple.hpp"
|
||||||
#include "tuple_helper.hpp"
|
#include "tuple_helper.hpp"
|
||||||
#include "type.hpp"
|
#include "type.hpp"
|
||||||
#include "utility.hpp"
|
|
||||||
#include "magic_division.hpp"
|
#include "magic_division.hpp"
|
||||||
#include "amd_buffer_addressing_v2.hpp"
|
#include "utility.hpp"
|
||||||
|
#include "c_style_pointer_cast.hpp"
|
||||||
|
#include "amd_address_space.hpp"
|
||||||
|
#include "amd_buffer_addressing.hpp"
|
||||||
#include "static_buffer.hpp"
|
#include "static_buffer.hpp"
|
||||||
#include "dynamic_buffer.hpp"
|
#include "dynamic_buffer.hpp"
|
||||||
|
|
||||||
|
#include "inner_product.hpp"
|
||||||
|
|
||||||
// TODO: remove this
|
// TODO: remove this
|
||||||
#if CK_USE_AMD_INLINE_ASM
|
#if CK_USE_AMD_INLINE_ASM
|
||||||
#include "amd_inline_asm.hpp"
|
#include "amd_inline_asm.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if CK_USE_AMD_DLOP
|
|
||||||
#include "amd_dlop.hpp"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if CK_USE_AMD_XDLOPS
|
#if CK_USE_AMD_XDLOPS
|
||||||
#include "amd_xdlops.hpp"
|
#include "amd_xdlops.hpp"
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -7,19 +7,14 @@
|
|||||||
#endif
|
#endif
|
||||||
#include "bfloat16_dev.hpp"
|
#include "bfloat16_dev.hpp"
|
||||||
|
|
||||||
// address space for kernel parameter
|
// "Constant" address space for kernel parameter
|
||||||
#define CONSTANT __attribute__((address_space(4)))
|
#define CONSTANT __attribute__((address_space(4)))
|
||||||
|
|
||||||
// GPU target
|
// GPU target
|
||||||
// should enable one and only one GPU target
|
// should enable one and only one GPU target
|
||||||
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
|
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
|
||||||
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
|
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
|
||||||
#error Need to define a single GPU target
|
#error Need to define (only) one GPU target
|
||||||
#endif
|
|
||||||
|
|
||||||
// HIP version
|
|
||||||
#ifndef CK_HIP_VERSION_FLAT
|
|
||||||
#define CK_HIP_VERSION_FLAT 0
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// launch bounds
|
// launch bounds
|
||||||
@@ -38,6 +33,16 @@
|
|||||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// FMA instruction
|
||||||
|
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
|
||||||
|
#define CK_USE_AMD_V_MAC_F32
|
||||||
|
#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \
|
||||||
|
defined(CK_AMD_GPU_GFX1030)
|
||||||
|
#define CK_USE_AMD_V_FMAC_F32
|
||||||
|
#define CK_USE_AMD_V_DOT2_F32_F16
|
||||||
|
#define CK_USE_AMD_V_DOT4_I32_I8
|
||||||
|
#endif
|
||||||
|
|
||||||
// multi index
|
// multi index
|
||||||
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
|
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
|
||||||
|
|
||||||
@@ -46,13 +51,9 @@
|
|||||||
#define CK_USE_AMD_INLINE_ASM 1
|
#define CK_USE_AMD_INLINE_ASM 1
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// AMD DLOPS
|
// AMD inner product (DLOP)
|
||||||
#ifndef CK_USE_AMD_DLOP
|
#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||||
#define CK_USE_AMD_DLOP 1
|
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef CK_USE_AMD_DLOP_INLINE_ASM
|
|
||||||
#define CK_USE_AMD_DLOP_INLINE_ASM 1
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// AMD buffer addressing
|
// AMD buffer addressing
|
||||||
@@ -99,8 +100,8 @@
|
|||||||
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
||||||
// thread-invariant, otherwise it's a bug
|
// thread-invariant, otherwise it's a bug
|
||||||
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
|
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
|
||||||
#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
||||||
#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
|
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// workaround for compiler crash when compiling recursive lambda
|
// workaround for compiler crash when compiling recursive lambda
|
||||||
@@ -120,15 +121,6 @@
|
|||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
enum AddressSpaceEnum_t
|
|
||||||
{
|
|
||||||
Generic,
|
|
||||||
Global,
|
|
||||||
Lds,
|
|
||||||
Sgpr,
|
|
||||||
Vgpr
|
|
||||||
};
|
|
||||||
|
|
||||||
enum InMemoryDataOperationEnum_t
|
enum InMemoryDataOperationEnum_t
|
||||||
{
|
{
|
||||||
Set,
|
Set,
|
||||||
|
|||||||
@@ -3,8 +3,7 @@
|
|||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
// this enumerate should be synchronized with include/miopen.h
|
enum DataTypeEnum_t
|
||||||
typedef enum
|
|
||||||
{
|
{
|
||||||
Half = 0,
|
Half = 0,
|
||||||
Float = 1,
|
Float = 1,
|
||||||
@@ -14,7 +13,7 @@ typedef enum
|
|||||||
BFloat16 = 5,
|
BFloat16 = 5,
|
||||||
Double = 6,
|
Double = 6,
|
||||||
Unknown = 100,
|
Unknown = 100,
|
||||||
} DataTypeEnum_t;
|
};
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#ifndef CK_DATA_TYPE_HELPER_HPP
|
#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
|
||||||
#define CK_DATA_TYPE_HELPER_HPP
|
#define CK_DATA_TYPE_ENUM_HELPER_HPP
|
||||||
|
|
||||||
#include "data_type.hpp"
|
#include "data_type.hpp"
|
||||||
#include "data_type_enum.hpp"
|
#include "data_type_enum.hpp"
|
||||||
@@ -1,38 +1,49 @@
|
|||||||
#ifndef CK_DYNAMIC_BUFFER_HPP
|
#ifndef CK_BUFFER_HPP
|
||||||
#define CK_DYNAMIC_BUFFER_HPP
|
#define CK_BUFFER_HPP
|
||||||
|
|
||||||
|
#include "amd_buffer_addressing.hpp"
|
||||||
|
#include "c_style_pointer_cast.hpp"
|
||||||
|
#include "enable_if.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
#include "amd_buffer_addressing_v2.hpp"
|
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||||
|
typename T,
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
typename ElementSpaceSize,
|
||||||
|
bool InvalidElementUseNumericalZeroValue>
|
||||||
struct DynamicBuffer
|
struct DynamicBuffer
|
||||||
{
|
{
|
||||||
using type = T;
|
using type = T;
|
||||||
|
|
||||||
T* p_data_;
|
T* p_data_;
|
||||||
ElementSpaceSize element_space_size_;
|
ElementSpaceSize element_space_size_;
|
||||||
|
T invalid_element_value_ = T{0};
|
||||||
|
|
||||||
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
||||||
: p_data_{p_data}, element_space_size_{element_space_size}
|
: p_data_{p_data}, element_space_size_{element_space_size}
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__host__ __device__ constexpr DynamicBuffer(T* p_data,
|
||||||
|
ElementSpaceSize element_space_size,
|
||||||
|
T invalid_element_value)
|
||||||
|
: p_data_{p_data},
|
||||||
|
element_space_size_{element_space_size},
|
||||||
|
invalid_element_value_{invalid_element_value}
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||||
{
|
{
|
||||||
return BufferAddressSpace;
|
return BufferAddressSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
__host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }
|
|
||||||
|
|
||||||
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
|
|
||||||
|
|
||||||
template <typename X,
|
template <typename X,
|
||||||
typename std::enable_if<
|
typename enable_if<
|
||||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
|
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||||
{
|
{
|
||||||
// X contains multiple T
|
// X contains multiple T
|
||||||
constexpr index_t scalar_per_t_vector =
|
constexpr index_t scalar_per_t_vector =
|
||||||
@@ -44,29 +55,50 @@ struct DynamicBuffer
|
|||||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||||
"wrong! X need to be multiple T");
|
"wrong! X need to be multiple T");
|
||||||
|
|
||||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
|
||||||
|
|
||||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
|
||||||
{
|
|
||||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||||
return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
bool constexpr use_amd_buffer_addressing = true;
|
||||||
p_data_, i, is_valid_offset, element_space_size_);
|
|
||||||
#else
|
#else
|
||||||
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
|
bool constexpr use_amd_buffer_addressing = false;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
|
||||||
|
{
|
||||||
|
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||||
|
|
||||||
|
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||||
|
{
|
||||||
|
return amd_buffer_load_invalid_element_return_return_zero<
|
||||||
|
remove_cv_t<remove_reference_t<T>>,
|
||||||
|
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return amd_buffer_load_invalid_element_return_customized_value<
|
||||||
|
remove_cv_t<remove_reference_t<T>>,
|
||||||
|
t_per_x>(
|
||||||
|
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return is_valid_offset ? *reinterpret_cast<const X*>(&p_data_[i]) : X{0};
|
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||||
|
{
|
||||||
|
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
|
||||||
|
: X{invalid_element_value_};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X,
|
template <typename X,
|
||||||
typename std::enable_if<
|
typename enable_if<
|
||||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
__host__ __device__ void Set(index_t i, bool is_valid_offset, const X& x)
|
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||||
{
|
{
|
||||||
// X contains multiple T
|
// X contains multiple T
|
||||||
constexpr index_t scalar_per_t_vector =
|
constexpr index_t scalar_per_t_vector =
|
||||||
@@ -78,26 +110,26 @@ struct DynamicBuffer
|
|||||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||||
"wrong! X need to be multiple T");
|
"wrong! X need to be multiple T");
|
||||||
|
|
||||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
|
||||||
|
|
||||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
||||||
{
|
{
|
||||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||||
amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||||
x, p_data_, i, is_valid_offset, element_space_size_);
|
|
||||||
|
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||||
|
x, p_data_, i, is_valid_element, element_space_size_);
|
||||||
#else
|
#else
|
||||||
if(is_valid_offset)
|
if(is_valid_element)
|
||||||
{
|
{
|
||||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
|
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
|
||||||
{
|
{
|
||||||
if(is_valid_offset)
|
if(is_valid_element)
|
||||||
{
|
{
|
||||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
#else
|
#else
|
||||||
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
|
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
|
||||||
// inefficient
|
// inefficient
|
||||||
@@ -128,24 +160,24 @@ struct DynamicBuffer
|
|||||||
{
|
{
|
||||||
// HACK: cast pointer of x is bad
|
// HACK: cast pointer of x is bad
|
||||||
// TODO: remove this after compiler fix
|
// TODO: remove this after compiler fix
|
||||||
*reinterpret_cast<int8_t*>(&p_data_[i]) =
|
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
|
||||||
*reinterpret_cast<const int8_t*>(&x);
|
*c_style_pointer_cast<const int8_t*>(&x);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
|
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
|
||||||
{
|
{
|
||||||
// HACK: cast pointer of x is bad
|
// HACK: cast pointer of x is bad
|
||||||
// TODO: remove this after compiler fix
|
// TODO: remove this after compiler fix
|
||||||
*reinterpret_cast<int16_t*>(&p_data_[i]) =
|
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
|
||||||
*reinterpret_cast<const int16_t*>(&x);
|
*c_style_pointer_cast<const int16_t*>(&x);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
||||||
{
|
{
|
||||||
// HACK: cast pointer of x is bad
|
// HACK: cast pointer of x is bad
|
||||||
// TODO: remove this after compiler fix
|
// TODO: remove this after compiler fix
|
||||||
*reinterpret_cast<int32_t*>(&p_data_[i]) =
|
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||||
*reinterpret_cast<const int32_t*>(&x);
|
*c_style_pointer_cast<const int32_t*>(&x);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||||
int8x4_t>::value &&
|
int8x4_t>::value &&
|
||||||
@@ -153,8 +185,8 @@ struct DynamicBuffer
|
|||||||
{
|
{
|
||||||
// HACK: cast pointer of x is bad
|
// HACK: cast pointer of x is bad
|
||||||
// TODO: remove this after compiler fix
|
// TODO: remove this after compiler fix
|
||||||
*reinterpret_cast<int32_t*>(&p_data_[i]) =
|
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||||
*reinterpret_cast<const int32_t*>(&x);
|
*c_style_pointer_cast<const int32_t*>(&x);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||||
int8x8_t>::value &&
|
int8x8_t>::value &&
|
||||||
@@ -162,8 +194,8 @@ struct DynamicBuffer
|
|||||||
{
|
{
|
||||||
// HACK: cast pointer of x is bad
|
// HACK: cast pointer of x is bad
|
||||||
// TODO: remove this after compiler fix
|
// TODO: remove this after compiler fix
|
||||||
*reinterpret_cast<int32x2_t*>(&p_data_[i]) =
|
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||||
*reinterpret_cast<const int32x2_t*>(&x);
|
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||||
}
|
}
|
||||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||||
int8x16_t>::value &&
|
int8x16_t>::value &&
|
||||||
@@ -171,22 +203,22 @@ struct DynamicBuffer
|
|||||||
{
|
{
|
||||||
// HACK: cast pointer of x is bad
|
// HACK: cast pointer of x is bad
|
||||||
// TODO: remove this after compiler fix
|
// TODO: remove this after compiler fix
|
||||||
*reinterpret_cast<int32x4_t*>(&p_data_[i]) =
|
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
|
||||||
*reinterpret_cast<const int32x4_t*>(&x);
|
*c_style_pointer_cast<const int32x4_t*>(&x);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
if(is_valid_offset)
|
if(is_valid_element)
|
||||||
{
|
{
|
||||||
*reinterpret_cast<X*>(&p_data_[i]) = x;
|
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -196,12 +228,18 @@ struct DynamicBuffer
|
|||||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||||
typename T,
|
|
||||||
typename ElementSpaceSize>
|
|
||||||
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
|
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
|
||||||
{
|
{
|
||||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize>{p, element_space_size};
|
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||||
|
__host__ __device__ constexpr auto
|
||||||
|
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
|
||||||
|
{
|
||||||
|
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
|
||||||
|
p, element_space_size, invalid_element_value};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
13
composable_kernel/include/utility/enable_if.hpp
Normal file
13
composable_kernel/include/utility/enable_if.hpp
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#ifndef CK_ENABLE_IF_HPP
|
||||||
|
#define CK_ENABLE_IF_HPP
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
|
||||||
|
template <bool B, typename T = void>
|
||||||
|
using enable_if = std::enable_if<B, T>;
|
||||||
|
|
||||||
|
template <bool B, typename T = void>
|
||||||
|
using enable_if_t = typename std::enable_if<B, T>::type;
|
||||||
|
|
||||||
|
} // namespace ck
|
||||||
|
#endif
|
||||||
207
composable_kernel/include/utility/inner_product.hpp
Normal file
207
composable_kernel/include/utility/inner_product.hpp
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
#ifndef CK_INNER_PRODUCT_HPP
|
||||||
|
#define CK_INNER_PRODUCT_HPP
|
||||||
|
|
||||||
|
#include "data_type.hpp"
|
||||||
|
|
||||||
|
namespace ck {
|
||||||
|
|
||||||
|
template <typename TA, typename TB, typename TC>
|
||||||
|
__device__ void inner_product(const TA& a, const TB& b, TC& c);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
|
||||||
|
{
|
||||||
|
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
|
||||||
|
asm volatile("\n \
|
||||||
|
v_mac_f32 %0, %1, %2 \n \
|
||||||
|
"
|
||||||
|
: "=v"(c)
|
||||||
|
: "v"(a), "v"(b), "0"(c));
|
||||||
|
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
|
||||||
|
asm volatile("\n \
|
||||||
|
v_fmac_f32 %0, %1, %2 \n \
|
||||||
|
"
|
||||||
|
: "=v"(c)
|
||||||
|
: "v"(a), "v"(b), "0"(c));
|
||||||
|
#else
|
||||||
|
c += a * b;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void
|
||||||
|
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
|
||||||
|
inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
|
||||||
|
vector_type<float, 2>{b}.AsType<float>()[I0],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
|
||||||
|
vector_type<float, 2>{b}.AsType<float>()[I1],
|
||||||
|
c);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void
|
||||||
|
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
constexpr auto I3 = Number<3>{};
|
||||||
|
|
||||||
|
inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
|
||||||
|
vector_type<float, 4>{b}.AsType<float>()[I0],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
|
||||||
|
vector_type<float, 4>{b}.AsType<float>()[I1],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
|
||||||
|
vector_type<float, 4>{b}.AsType<float>()[I2],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
|
||||||
|
vector_type<float, 4>{b}.AsType<float>()[I3],
|
||||||
|
c);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
|
||||||
|
{
|
||||||
|
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
|
||||||
|
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||||
|
asm volatile("\n \
|
||||||
|
v_dot2_f32_f16 %0, %1, %2, %0\n \
|
||||||
|
"
|
||||||
|
: "=v"(c)
|
||||||
|
: "v"(a), "v"(b), "0"(c));
|
||||||
|
#else
|
||||||
|
c = __builtin_amdgcn_sdot2(a, b, c, false);
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
const auto convert = type_convert<int32_t>{};
|
||||||
|
|
||||||
|
const vector_type<half_t, 2> a_vector{a};
|
||||||
|
const vector_type<half_t, 2> b_vector{b};
|
||||||
|
|
||||||
|
static_for<0, 2, 1>{}([&](auto i) {
|
||||||
|
c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
|
||||||
|
});
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
|
||||||
|
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
|
||||||
|
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
|
||||||
|
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
|
||||||
|
c);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
constexpr auto I3 = Number<3>{};
|
||||||
|
|
||||||
|
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
|
||||||
|
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
|
||||||
|
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
|
||||||
|
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
|
||||||
|
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
|
||||||
|
c);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void
|
||||||
|
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
|
||||||
|
{
|
||||||
|
#if defined(CK_USE_DOT4_I32_I8)
|
||||||
|
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||||
|
asm volatile("\n \
|
||||||
|
v_dot4_i32_i8 %0, %1, %2, %0\n \
|
||||||
|
"
|
||||||
|
: "=v"(c)
|
||||||
|
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
|
||||||
|
#else
|
||||||
|
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
const auto convert = type_convert<int32_t>{};
|
||||||
|
|
||||||
|
const vector_type<int8_t, 4> a_vector{a};
|
||||||
|
const vector_type<int8_t, 4> b_vector{b};
|
||||||
|
|
||||||
|
static_for<0, 4, 1>{}([&](auto i) {
|
||||||
|
c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
|
||||||
|
});
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void
|
||||||
|
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
|
||||||
|
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
|
||||||
|
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
|
||||||
|
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
|
||||||
|
c);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ void
|
||||||
|
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
constexpr auto I3 = Number<3>{};
|
||||||
|
|
||||||
|
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
|
||||||
|
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
|
||||||
|
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
|
||||||
|
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
|
||||||
|
c);
|
||||||
|
|
||||||
|
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
|
||||||
|
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
|
||||||
|
c);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ck
|
||||||
|
#endif
|
||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "integral_constant.hpp"
|
#include "integral_constant.hpp"
|
||||||
#include "number.hpp"
|
#include "number.hpp"
|
||||||
#include "type.hpp"
|
#include "type.hpp"
|
||||||
|
#include "enable_if.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
namespace math {
|
namespace math {
|
||||||
@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
|
|||||||
return Number<r>{};
|
return Number<r>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X,
|
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||||
typename... Ys,
|
|
||||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
|
||||||
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
||||||
{
|
{
|
||||||
return gcd(x, gcd(ys...));
|
return gcd(x, gcd(ys...));
|
||||||
@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
|
|||||||
return (x * y) / gcd(x, y);
|
return (x * y) / gcd(x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X,
|
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||||
typename... Ys,
|
|
||||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
|
||||||
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
|
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
|
||||||
{
|
{
|
||||||
return lcm(x, lcm(ys...));
|
return lcm(x, lcm(ys...));
|
||||||
|
|||||||
@@ -11,59 +11,11 @@ namespace ck {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__host__ __device__ void print_array(const char* s, T a)
|
__host__ __device__ void print_array(const char* s, T a)
|
||||||
{
|
{
|
||||||
using data_type = decltype(a.At(Number<0>{}));
|
|
||||||
constexpr index_t nsize = a.Size();
|
constexpr index_t nsize = a.Size();
|
||||||
|
|
||||||
#if 0
|
|
||||||
if constexpr(is_same<data_type, uint32_t>{})
|
|
||||||
{
|
|
||||||
printf("%s size %u, {", s, nsize);
|
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
|
|
||||||
printf("}\n");
|
|
||||||
}
|
|
||||||
else if constexpr(is_same<data_type, int32_t>{})
|
|
||||||
{
|
|
||||||
printf("%s size %d, {", s, nsize);
|
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
|
||||||
printf("}\n");
|
|
||||||
}
|
|
||||||
else if constexpr(is_same<data_type, bool>{})
|
|
||||||
{
|
|
||||||
printf("%s size %d, {", s, nsize);
|
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
|
|
||||||
printf("}\n");
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
printf("%s size %d, {", s, nsize);
|
printf("%s size %d, {", s, nsize);
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
||||||
printf("}\n");
|
printf("}\n");
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__host__ __device__ void print_array_v2(const char* s, T a)
|
|
||||||
{
|
|
||||||
using data_type = decltype(a.At(Number<0>{}));
|
|
||||||
constexpr index_t nsize = a.Size();
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
if constexpr(is_same<data_type, uint32_t>{})
|
|
||||||
{
|
|
||||||
printf("%s size %u, {", s, nsize);
|
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
|
|
||||||
printf("}\n");
|
|
||||||
}
|
|
||||||
else if constexpr(is_same<data_type, int32_t>{})
|
|
||||||
{
|
|
||||||
printf("%s size %d, {", s, nsize);
|
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
|
|
||||||
printf("}\n");
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
printf("%s size %d, {", s, nsize);
|
|
||||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
|
|
||||||
printf("}\n");
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
@@ -685,8 +685,6 @@ __host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
|
|||||||
template <index_t Y, index_t... Xs>
|
template <index_t Y, index_t... Xs>
|
||||||
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
|
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
|
||||||
{
|
{
|
||||||
constexpr auto seq_x = Sequence<Xs...>{};
|
|
||||||
|
|
||||||
return Sequence<(Y - Xs)...>{};
|
return Sequence<(Y - Xs)...>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,30 +5,66 @@
|
|||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||||
|
typename T,
|
||||||
|
index_t N,
|
||||||
|
bool InvalidElementUseNumericalZeroValue>
|
||||||
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||||
{
|
{
|
||||||
using type = T;
|
using type = T;
|
||||||
using base = StaticallyIndexedArray<T, N>;
|
using base = StaticallyIndexedArray<T, N>;
|
||||||
|
|
||||||
|
T invalid_element_value_ = T{0};
|
||||||
|
|
||||||
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
||||||
|
|
||||||
|
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
|
||||||
|
: base{}, invalid_element_value_{invalid_element_value}
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||||
{
|
{
|
||||||
return BufferAddressSpace;
|
return BufferAddressSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <index_t I>
|
||||||
|
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
|
||||||
|
{
|
||||||
|
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||||
|
{
|
||||||
|
return is_valid_element ? At(i) : T{0};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return is_valid_element ? At(i) : invalid_element_value_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <index_t I>
|
||||||
|
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
|
||||||
|
{
|
||||||
|
if(is_valid_element)
|
||||||
|
{
|
||||||
|
At(i) = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||||
|
|
||||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||||
typename T,
|
|
||||||
index_t N>
|
|
||||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||||
{
|
{
|
||||||
return StaticBuffer<BufferAddressSpace, T, N>{};
|
return StaticBuffer<BufferAddressSpace, T, N, true>{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||||
|
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
|
||||||
|
{
|
||||||
|
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck
|
} // namespace ck
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "integral_constant.hpp"
|
#include "integral_constant.hpp"
|
||||||
#include "sequence.hpp"
|
#include "sequence.hpp"
|
||||||
#include "type.hpp"
|
#include "type.hpp"
|
||||||
|
#include "enable_if.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -20,10 +21,9 @@ struct TupleElement
|
|||||||
{
|
{
|
||||||
__host__ __device__ constexpr TupleElement() = default;
|
__host__ __device__ constexpr TupleElement() = default;
|
||||||
|
|
||||||
template <
|
template <typename T,
|
||||||
typename T,
|
typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
|
||||||
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
|
bool>::type = false>
|
||||||
bool>::type = false>
|
|
||||||
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
|
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
|
|||||||
{
|
{
|
||||||
__host__ __device__ constexpr TupleImpl() = default;
|
__host__ __device__ constexpr TupleImpl() = default;
|
||||||
|
|
||||||
template <
|
template <typename Y,
|
||||||
typename Y,
|
typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
|
||||||
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
|
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
|
||||||
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
|
bool>::type = false>
|
||||||
bool>::type = false>
|
|
||||||
__host__ __device__ constexpr TupleImpl(Y&& y)
|
__host__ __device__ constexpr TupleImpl(Y&& y)
|
||||||
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
|
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||||
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
|
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
|
||||||
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
|
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
|
||||||
{
|
{
|
||||||
@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
|||||||
__host__ __device__ constexpr Tuple() = default;
|
__host__ __device__ constexpr Tuple() = default;
|
||||||
|
|
||||||
template <typename Y,
|
template <typename Y,
|
||||||
typename std::enable_if<
|
typename enable_if<sizeof...(Xs) == 1 &&
|
||||||
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
|
!is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
|
||||||
bool>::type = false>
|
bool>::type = false>
|
||||||
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
|
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Ys,
|
template <typename... Ys,
|
||||||
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
|
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
|
||||||
bool>::type = false>
|
false>
|
||||||
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
|
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#define CK_TYPE_HPP
|
#define CK_TYPE_HPP
|
||||||
|
|
||||||
#include "integral_constant.hpp"
|
#include "integral_constant.hpp"
|
||||||
|
#include "enable_if.hpp"
|
||||||
|
|
||||||
namespace ck {
|
namespace ck {
|
||||||
|
|
||||||
@@ -22,10 +23,7 @@ template <typename T>
|
|||||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
|
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
|
||||||
{
|
|
||||||
return static_cast<typename std::remove_reference<T>::type&&>(t);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct is_known_at_compile_time;
|
struct is_known_at_compile_time;
|
||||||
@@ -42,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
|
|||||||
static constexpr bool value = true;
|
static constexpr bool value = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Y,
|
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||||
typename X,
|
|
||||||
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
|
||||||
__host__ __device__ constexpr Y as_type(X x)
|
__host__ __device__ constexpr Y as_type(X x)
|
||||||
{
|
{
|
||||||
union AsType
|
union AsType
|
||||||
|
|||||||
@@ -0,0 +1,370 @@
|
|||||||
|
#include "common_header.hpp"
|
||||||
|
#include "tensor_descriptor.hpp"
|
||||||
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
#include "gridwise_gemm_dlops_v1r2.hpp"
|
||||||
|
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||||
|
|
||||||
|
using namespace ck;
|
||||||
|
|
||||||
|
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||||
|
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||||
|
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||||
|
|
||||||
|
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||||
|
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||||
|
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||||
|
|
||||||
|
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||||
|
|
||||||
|
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||||
|
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||||
|
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||||
|
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
|
||||||
|
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
|
||||||
|
constexpr index_t KPerThread = CK_PARAM_KPerThread;
|
||||||
|
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
|
||||||
|
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
|
||||||
|
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
|
||||||
|
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
|
||||||
|
|
||||||
|
using ABlockTransferThreadSliceLengths_K_M0_M1 =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
|
||||||
|
using ABlockTransferThreadClusterLengths_K_M0_M1 =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
|
||||||
|
using ABlockTransferThreadClusterArrangeOrder =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||||
|
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||||
|
|
||||||
|
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||||
|
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||||
|
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
|
||||||
|
CK_PARAM_ABlockTransferDstScalarPerVector_M1;
|
||||||
|
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||||
|
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||||
|
|
||||||
|
using BBlockTransferThreadSliceLengths_K_N0_N1 =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
|
||||||
|
using BBlockTransferThreadClusterLengths_K_N0_N1 =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
|
||||||
|
using BBlockTransferThreadClusterArrangeOrder =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||||
|
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||||
|
|
||||||
|
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||||
|
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||||
|
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
|
||||||
|
CK_PARAM_BBlockTransferDstScalarPerVector_N1;
|
||||||
|
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||||
|
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||||
|
|
||||||
|
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||||
|
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||||
|
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||||
|
|
||||||
|
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
|
||||||
|
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
|
||||||
|
|
||||||
|
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
|
||||||
|
int n,
|
||||||
|
int c,
|
||||||
|
int hi,
|
||||||
|
int wi,
|
||||||
|
int k,
|
||||||
|
int y,
|
||||||
|
int x,
|
||||||
|
int convStrideH,
|
||||||
|
int convStrideW,
|
||||||
|
int convDilationY,
|
||||||
|
int convDilationX,
|
||||||
|
int leftPadH,
|
||||||
|
int leftPadW,
|
||||||
|
int rightPadH,
|
||||||
|
int rightPadW,
|
||||||
|
void* p_a_k_m0_m1_grid_desc,
|
||||||
|
void* p_b_k_n0_n1_grid_desc,
|
||||||
|
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||||
|
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||||
|
|
||||||
|
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
|
||||||
|
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
|
||||||
|
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
|
||||||
|
|
||||||
|
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
||||||
|
wei_k_c_y_x_desc,
|
||||||
|
in_n_c_hi_wi_desc,
|
||||||
|
out_n_k_ho_wo_desc,
|
||||||
|
make_tuple(convStrideH, convStrideW),
|
||||||
|
make_tuple(convDilationY, convDilationX),
|
||||||
|
make_tuple(leftPadH, leftPadW),
|
||||||
|
make_tuple(rightPadH, rightPadW));
|
||||||
|
|
||||||
|
const auto a_k_m_grid_desc = descs[I0];
|
||||||
|
const auto b_k_n_grid_desc = descs[I1];
|
||||||
|
const auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
|
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
||||||
|
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||||
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using BGridStepHacks =
|
||||||
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
||||||
|
AKMGridDesc,
|
||||||
|
BKNGridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
M1PerThread,
|
||||||
|
N1PerThread,
|
||||||
|
KPerThread,
|
||||||
|
M1N1ThreadClusterM10,
|
||||||
|
M1N1ThreadClusterN10,
|
||||||
|
M1N1ThreadClusterM11,
|
||||||
|
M1N1ThreadClusterN11,
|
||||||
|
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||||
|
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_M1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||||
|
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_N1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
|
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||||
|
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||||
|
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||||
|
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||||
|
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
if(hipThreadIdx_x == 0)
|
||||||
|
{
|
||||||
|
*static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
|
||||||
|
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
|
||||||
|
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
|
||||||
|
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||||
|
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||||
|
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
extern "C" __global__ void
|
||||||
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
|
#endif
|
||||||
|
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||||
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
|
FloatC* __restrict__ p_c_grid,
|
||||||
|
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
||||||
|
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
||||||
|
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
constexpr auto in_n_c_hi_wi_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||||
|
constexpr auto wei_k_c_y_x_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||||
|
constexpr auto out_n_k_ho_wo_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||||
|
|
||||||
|
constexpr auto descs =
|
||||||
|
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||||
|
in_n_c_hi_wi_desc,
|
||||||
|
out_n_k_ho_wo_desc,
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1));
|
||||||
|
|
||||||
|
constexpr auto a_k_m_grid_desc = descs[I0];
|
||||||
|
constexpr auto b_k_n_grid_desc = descs[I1];
|
||||||
|
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
|
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
||||||
|
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||||
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using BGridStepHacks =
|
||||||
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
||||||
|
AKMGridDesc,
|
||||||
|
BKNGridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
M1PerThread,
|
||||||
|
N1PerThread,
|
||||||
|
KPerThread,
|
||||||
|
M1N1ThreadClusterM10,
|
||||||
|
M1N1ThreadClusterN10,
|
||||||
|
M1N1ThreadClusterM11,
|
||||||
|
M1N1ThreadClusterN11,
|
||||||
|
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||||
|
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_M1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||||
|
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_N1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
|
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
||||||
|
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||||
|
constexpr auto b_k_n0_n1_grid_desc_tmp =
|
||||||
|
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||||
|
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
|
||||||
|
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||||
|
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||||
|
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
|
||||||
|
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
|
||||||
|
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
|
||||||
|
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||||
|
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||||
|
|
||||||
|
const auto a_k_m0_m1_grid_desc =
|
||||||
|
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
|
||||||
|
const auto b_k_n0_n1_grid_desc =
|
||||||
|
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
|
||||||
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||||
|
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||||
|
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||||
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||||
|
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
constexpr index_t shared_block_size =
|
||||||
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
|
|
||||||
|
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||||
|
|
||||||
|
GridwiseGemm::Run(p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
p_shared_block,
|
||||||
|
a_k_m0_m1_grid_desc,
|
||||||
|
b_k_n0_n1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||||
|
integral_constant<bool, HasMainKBlockLoop>{},
|
||||||
|
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||||
|
};
|
||||||
@@ -0,0 +1,358 @@
|
|||||||
|
#include "common_header.hpp"
|
||||||
|
#include "tensor_descriptor.hpp"
|
||||||
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||||
|
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||||
|
|
||||||
|
using namespace ck;
|
||||||
|
|
||||||
|
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||||
|
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||||
|
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||||
|
|
||||||
|
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||||
|
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||||
|
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||||
|
|
||||||
|
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||||
|
|
||||||
|
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||||
|
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||||
|
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||||
|
|
||||||
|
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||||
|
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||||
|
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||||
|
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||||
|
constexpr index_t K1 = CK_PARAM_K1;
|
||||||
|
|
||||||
|
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||||
|
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||||
|
using ABlockTransferThreadClusterArrangeOrder =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||||
|
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||||
|
|
||||||
|
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||||
|
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||||
|
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||||
|
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||||
|
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||||
|
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||||
|
|
||||||
|
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||||
|
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||||
|
using BBlockTransferThreadClusterArrangeOrder =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||||
|
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||||
|
|
||||||
|
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||||
|
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||||
|
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||||
|
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||||
|
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||||
|
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||||
|
|
||||||
|
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||||
|
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||||
|
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||||
|
|
||||||
|
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
|
||||||
|
int n,
|
||||||
|
int c,
|
||||||
|
int hi,
|
||||||
|
int wi,
|
||||||
|
int k,
|
||||||
|
int y,
|
||||||
|
int x,
|
||||||
|
int convStrideH,
|
||||||
|
int convStrideW,
|
||||||
|
int convDilationY,
|
||||||
|
int convDilationX,
|
||||||
|
int leftPadH,
|
||||||
|
int leftPadW,
|
||||||
|
int rightPadH,
|
||||||
|
int rightPadW,
|
||||||
|
void* p_a_k0_m_k1_grid_desc,
|
||||||
|
void* p_b_k0_n_k1_grid_desc,
|
||||||
|
void* p_c_m0_m1_m2_n_grid_desc,
|
||||||
|
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||||
|
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||||
|
|
||||||
|
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
|
||||||
|
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
|
||||||
|
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
|
||||||
|
|
||||||
|
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||||
|
wei_k_c_y_x_desc,
|
||||||
|
in_n_c_hi_wi_desc,
|
||||||
|
out_n_k_ho_wo_desc,
|
||||||
|
make_tuple(convStrideH, convStrideW),
|
||||||
|
make_tuple(convDilationY, convDilationX),
|
||||||
|
make_tuple(leftPadH, leftPadW),
|
||||||
|
make_tuple(rightPadH, rightPadW),
|
||||||
|
Number<K1>{});
|
||||||
|
|
||||||
|
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||||
|
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||||
|
const auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
|
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||||
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||||
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using AGridStepHacks = decltype(make_tuple(
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
|
make_tuple(
|
||||||
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using BGridStepHacks =
|
||||||
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
InMemoryDataOperationEnum_t::Set,
|
||||||
|
AK0MK1GridDesc,
|
||||||
|
BK0NK1GridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
MPerWave,
|
||||||
|
NPerWave,
|
||||||
|
K1,
|
||||||
|
MRepeat,
|
||||||
|
NRepeat,
|
||||||
|
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
if(hipThreadIdx_x == 0)
|
||||||
|
{
|
||||||
|
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||||
|
a_k0_m_k1_grid_desc;
|
||||||
|
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||||
|
b_k0_n_k1_grid_desc;
|
||||||
|
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||||
|
c_m0_m1_m2_n_grid_desc;
|
||||||
|
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||||
|
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
extern "C" __global__ void
|
||||||
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
|
#endif
|
||||||
|
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||||
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
|
FloatC* __restrict__ p_c_grid,
|
||||||
|
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||||
|
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||||
|
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||||
|
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
|
{
|
||||||
|
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
constexpr auto in_n_c_hi_wi_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||||
|
constexpr auto wei_k_c_y_x_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||||
|
constexpr auto out_n_k_ho_wo_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||||
|
|
||||||
|
constexpr auto descs =
|
||||||
|
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||||
|
in_n_c_hi_wi_desc,
|
||||||
|
out_n_k_ho_wo_desc,
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
Number<K1>{});
|
||||||
|
|
||||||
|
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||||
|
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||||
|
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
|
using AGridStepHacks = decltype(make_tuple(
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
|
make_tuple(
|
||||||
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using BGridStepHacks =
|
||||||
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
|
||||||
|
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||||
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||||
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
InMemoryDataOperationEnum_t::Set,
|
||||||
|
AK0MK1GridDesc,
|
||||||
|
BK0NK1GridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
MPerWave,
|
||||||
|
NPerWave,
|
||||||
|
K1,
|
||||||
|
MRepeat,
|
||||||
|
NRepeat,
|
||||||
|
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||||
|
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
|
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||||
|
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||||
|
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||||
|
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||||
|
|
||||||
|
const auto a_k0_m_k1_grid_desc =
|
||||||
|
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||||
|
const auto b_k0_n_k1_grid_desc =
|
||||||
|
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||||
|
const auto c_m0_m1_m2_n_grid_desc =
|
||||||
|
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||||
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||||
|
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
constexpr index_t shared_block_size =
|
||||||
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
|
|
||||||
|
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||||
|
|
||||||
|
GridwiseGemm::Run(p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
p_shared_block,
|
||||||
|
a_k0_m_k1_grid_desc,
|
||||||
|
b_k0_n_k1_grid_desc,
|
||||||
|
c_m0_m1_m2_n_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
};
|
||||||
@@ -0,0 +1,357 @@
|
|||||||
|
#include "common_header.hpp"
|
||||||
|
#include "tensor_descriptor.hpp"
|
||||||
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||||
|
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||||
|
|
||||||
|
using namespace ck;
|
||||||
|
|
||||||
|
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||||
|
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||||
|
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||||
|
|
||||||
|
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||||
|
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||||
|
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||||
|
|
||||||
|
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||||
|
|
||||||
|
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||||
|
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||||
|
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||||
|
|
||||||
|
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||||
|
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||||
|
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||||
|
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||||
|
constexpr index_t K1 = CK_PARAM_K1;
|
||||||
|
|
||||||
|
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||||
|
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||||
|
using ABlockTransferThreadClusterArrangeOrder =
|
||||||
|
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||||
|
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||||
|
|
||||||
|
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||||
|
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||||
|
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||||
|
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||||
|
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||||
|
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||||
|
|
||||||
|
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||||
|
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||||
|
using BBlockTransferThreadClusterArrangeOrder =
|
||||||
|
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||||
|
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||||
|
|
||||||
|
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||||
|
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||||
|
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||||
|
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||||
|
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||||
|
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||||
|
|
||||||
|
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||||
|
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||||
|
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||||
|
|
||||||
|
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
|
||||||
|
int n,
|
||||||
|
int hi,
|
||||||
|
int wi,
|
||||||
|
int c,
|
||||||
|
int k,
|
||||||
|
int y,
|
||||||
|
int x,
|
||||||
|
int convStrideH,
|
||||||
|
int convStrideW,
|
||||||
|
int convDilationY,
|
||||||
|
int convDilationX,
|
||||||
|
int leftPadH,
|
||||||
|
int leftPadW,
|
||||||
|
int rightPadH,
|
||||||
|
int rightPadW,
|
||||||
|
void* p_a_k0_m_k1_grid_desc,
|
||||||
|
void* p_b_k0_n_k1_grid_desc,
|
||||||
|
void* p_c_m0_m1_m2_n_grid_desc,
|
||||||
|
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
|
{
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||||
|
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||||
|
|
||||||
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c));
|
||||||
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
|
||||||
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));
|
||||||
|
|
||||||
|
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||||
|
in_n_hi_wi_c_desc,
|
||||||
|
wei_k_y_x_c_desc,
|
||||||
|
out_n_ho_wo_k_desc,
|
||||||
|
make_tuple(convStrideH, convStrideW),
|
||||||
|
make_tuple(convDilationY, convDilationX),
|
||||||
|
make_tuple(leftPadH, leftPadW),
|
||||||
|
make_tuple(rightPadH, rightPadW),
|
||||||
|
Number<K1>{});
|
||||||
|
|
||||||
|
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||||
|
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||||
|
const auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
|
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||||
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||||
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
|
make_tuple(
|
||||||
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridStepHacks =
|
||||||
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
InMemoryDataOperationEnum_t::Set,
|
||||||
|
AK0MK1GridDesc,
|
||||||
|
BK0NK1GridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
MPerWave,
|
||||||
|
NPerWave,
|
||||||
|
K1,
|
||||||
|
MRepeat,
|
||||||
|
NRepeat,
|
||||||
|
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
if(hipThreadIdx_x == 0)
|
||||||
|
{
|
||||||
|
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||||
|
a_k0_m_k1_grid_desc;
|
||||||
|
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||||
|
b_k0_n_k1_grid_desc;
|
||||||
|
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||||
|
c_m0_m1_m2_n_grid_desc;
|
||||||
|
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||||
|
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
extern "C" __global__ void
|
||||||
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
|
#endif
|
||||||
|
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
||||||
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
|
FloatC* __restrict__ p_c_grid,
|
||||||
|
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||||
|
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||||
|
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||||
|
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||||
|
{
|
||||||
|
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
constexpr auto in_n_hi_wi_c_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||||
|
constexpr auto wei_k_y_x_c_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256));
|
||||||
|
constexpr auto out_n_ho_wo_k_desc =
|
||||||
|
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||||
|
|
||||||
|
constexpr auto descs =
|
||||||
|
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||||
|
wei_k_y_x_c_desc,
|
||||||
|
out_n_ho_wo_k_desc,
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
make_tuple(1, 1),
|
||||||
|
Number<K1>{});
|
||||||
|
|
||||||
|
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||||
|
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||||
|
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||||
|
|
||||||
|
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||||
|
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||||
|
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
|
make_tuple(
|
||||||
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridStepHacks =
|
||||||
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||||
|
|
||||||
|
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 1, 0, 0>{}),
|
||||||
|
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
|
Sequence<0, 0, 2, 0, 0>{})));
|
||||||
|
|
||||||
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||||
|
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
InMemoryDataOperationEnum_t::Set,
|
||||||
|
AK0MK1GridDesc,
|
||||||
|
BK0NK1GridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
MPerWave,
|
||||||
|
NPerWave,
|
||||||
|
K1,
|
||||||
|
MRepeat,
|
||||||
|
NRepeat,
|
||||||
|
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
false>;
|
||||||
|
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||||
|
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
|
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||||
|
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||||
|
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||||
|
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||||
|
|
||||||
|
const auto a_k0_m_k1_grid_desc =
|
||||||
|
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||||
|
const auto b_k0_n_k1_grid_desc =
|
||||||
|
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||||
|
const auto c_m0_m1_m2_n_grid_desc =
|
||||||
|
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||||
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||||
|
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
constexpr index_t shared_block_size =
|
||||||
|
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||||
|
|
||||||
|
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||||
|
|
||||||
|
GridwiseGemm::Run(p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
p_shared_block,
|
||||||
|
a_k0_m_k1_grid_desc,
|
||||||
|
b_k0_n_k1_grid_desc,
|
||||||
|
c_m0_m1_m2_n_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
};
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
|
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||||
|
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
@@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBloc
|
|||||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
|
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
|
||||||
|
|
||||||
extern "C" __global__ void
|
extern "C" __global__ void
|
||||||
dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
|
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
|
||||||
index_t C,
|
int C_,
|
||||||
index_t Hi,
|
int Hi_,
|
||||||
index_t Wi,
|
int Wi_,
|
||||||
index_t K,
|
int K_,
|
||||||
index_t Y,
|
int Y_,
|
||||||
index_t X,
|
int X_,
|
||||||
index_t ConvStrideH,
|
int ConvStrideH_,
|
||||||
index_t ConvStrideW,
|
int ConvStrideW_,
|
||||||
index_t ConvDilationH,
|
int ConvDilationH_,
|
||||||
index_t ConvDilationW,
|
int ConvDilationW_,
|
||||||
index_t InLeftPadH,
|
int InLeftPadH_,
|
||||||
index_t InLeftPadW,
|
int InLeftPadW_,
|
||||||
index_t InRightPadH,
|
int InRightPadH_,
|
||||||
index_t InRightPadW,
|
int InRightPadW_,
|
||||||
void* p_desc_tuple)
|
void* p_desc_tuple)
|
||||||
{
|
{
|
||||||
|
index_t N = static_cast<index_t>(N_);
|
||||||
|
index_t C = static_cast<index_t>(C_);
|
||||||
|
index_t Hi = static_cast<index_t>(Hi_);
|
||||||
|
index_t Wi = static_cast<index_t>(Wi_);
|
||||||
|
index_t K = static_cast<index_t>(K_);
|
||||||
|
index_t Y = static_cast<index_t>(Y_);
|
||||||
|
index_t X = static_cast<index_t>(X_);
|
||||||
|
index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
|
||||||
|
index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
|
||||||
|
index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
|
||||||
|
index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
|
||||||
|
index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
|
||||||
|
index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
|
||||||
|
index_t InRightPadH = static_cast<index_t>(InRightPadH_);
|
||||||
|
index_t InRightPadW = static_cast<index_t>(InRightPadW_);
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
constexpr auto I0 = Number<0>{};
|
||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
@@ -88,12 +104,9 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
|
|||||||
const index_t Wo =
|
const index_t Wo =
|
||||||
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
|
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi));
|
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
|
||||||
const auto wei_k_c_y_x_desc =
|
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X));
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
|
|
||||||
|
|
||||||
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||||
wei_k_c_y_x_desc,
|
wei_k_c_y_x_desc,
|
||||||
@@ -114,7 +127,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
|
|||||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
using AGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||||
@@ -126,7 +139,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||||
@@ -138,7 +151,7 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(
|
using CGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||||
@@ -154,13 +167,13 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using BGridMoveSliceWindowIteratorHacks =
|
using BGridMoveSliceWindowStepHacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using GridwiseContraction =
|
using GridwiseContraction =
|
||||||
GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
@@ -194,11 +207,11 @@ dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(inde
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||||
{
|
{
|
||||||
@@ -220,7 +233,7 @@ extern "C" __global__ void
|
|||||||
#if CK_USE_LAUNCH_BOUNDS
|
#if CK_USE_LAUNCH_BOUNDS
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||||
#endif
|
#endif
|
||||||
dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
const FloatAB* __restrict__ p_a_grid,
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
const FloatAB* __restrict__ p_b_grid,
|
||||||
FloatC* __restrict__ p_c_grid,
|
FloatC* __restrict__ p_c_grid,
|
||||||
@@ -232,11 +245,11 @@ extern "C" __global__ void
|
|||||||
constexpr auto I3 = Number<3>{};
|
constexpr auto I3 = Number<3>{};
|
||||||
|
|
||||||
constexpr auto in_n_c_hi_wi_desc =
|
constexpr auto in_n_c_hi_wi_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||||
constexpr auto wei_k_c_y_x_desc =
|
constexpr auto wei_k_c_y_x_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||||
constexpr auto out_n_k_ho_wo_desc =
|
constexpr auto out_n_k_ho_wo_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||||
|
|
||||||
constexpr auto descs =
|
constexpr auto descs =
|
||||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||||
@@ -257,7 +270,7 @@ extern "C" __global__ void
|
|||||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
using AGridStepHacks =
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||||
@@ -269,7 +282,7 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
using BGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||||
@@ -281,7 +294,7 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(
|
using CGridStepHacks = decltype(make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||||
@@ -297,13 +310,13 @@ extern "C" __global__ void
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using BGridMoveSliceWindowIteratorHacks =
|
using BGridMoveSliceWindowStepHacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||||
|
|
||||||
using GridwiseContraction =
|
using GridwiseContraction =
|
||||||
GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
@@ -337,11 +350,11 @@ extern "C" __global__ void
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||||
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||||
@@ -1,374 +0,0 @@
|
|||||||
#include "common_header.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
|
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
|
||||||
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
|
||||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
|
||||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
|
||||||
|
|
||||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
|
||||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
|
||||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
|
||||||
|
|
||||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
|
||||||
|
|
||||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
|
||||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
|
||||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
|
||||||
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
|
|
||||||
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
|
|
||||||
constexpr index_t KPerThread = CK_PARAM_KPerThread;
|
|
||||||
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
|
|
||||||
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
|
|
||||||
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
|
|
||||||
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
|
|
||||||
|
|
||||||
using ABlockTransferThreadSliceLengths_K_M0_M1 =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
|
|
||||||
using ABlockTransferThreadClusterLengths_K_M0_M1 =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
|
|
||||||
using ABlockTransferThreadClusterArrangeOrder =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
|
||||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
|
||||||
|
|
||||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
|
||||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
|
||||||
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
|
|
||||||
CK_PARAM_ABlockTransferDstScalarPerVector_M1;
|
|
||||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
|
||||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
using BBlockTransferThreadSliceLengths_K_N0_N1 =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
|
|
||||||
using BBlockTransferThreadClusterLengths_K_N0_N1 =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
|
|
||||||
using BBlockTransferThreadClusterArrangeOrder =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
|
||||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
|
||||||
|
|
||||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
|
||||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
|
||||||
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
|
|
||||||
CK_PARAM_BBlockTransferDstScalarPerVector_N1;
|
|
||||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
|
||||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
|
||||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
|
||||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
|
||||||
|
|
||||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
|
|
||||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
|
|
||||||
|
|
||||||
extern "C" __global__ void
|
|
||||||
dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
|
|
||||||
int n,
|
|
||||||
int c,
|
|
||||||
int hi,
|
|
||||||
int wi,
|
|
||||||
int k,
|
|
||||||
int y,
|
|
||||||
int x,
|
|
||||||
int convStrideH,
|
|
||||||
int convStrideW,
|
|
||||||
int convDilationY,
|
|
||||||
int convDilationX,
|
|
||||||
int leftPadH,
|
|
||||||
int leftPadW,
|
|
||||||
int rightPadH,
|
|
||||||
int rightPadW,
|
|
||||||
void* p_a_k_m0_m1_grid_desc,
|
|
||||||
void* p_b_k_n0_n1_grid_desc,
|
|
||||||
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
|
|
||||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
|
||||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
|
|
||||||
const auto wei_k_c_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
|
|
||||||
|
|
||||||
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
|
||||||
wei_k_c_y_x_desc,
|
|
||||||
in_n_c_hi_wi_desc,
|
|
||||||
out_n_k_ho_wo_desc,
|
|
||||||
make_tuple(convStrideH, convStrideW),
|
|
||||||
make_tuple(convDilationY, convDilationX),
|
|
||||||
make_tuple(leftPadH, leftPadW),
|
|
||||||
make_tuple(rightPadH, rightPadW));
|
|
||||||
|
|
||||||
const auto a_k_m_grid_desc = descs[I0];
|
|
||||||
const auto b_k_n_grid_desc = descs[I1];
|
|
||||||
const auto c_m_n_grid_desc = descs[I2];
|
|
||||||
|
|
||||||
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
|
||||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
|
||||||
AKMGridDesc,
|
|
||||||
BKNGridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
M1PerThread,
|
|
||||||
N1PerThread,
|
|
||||||
KPerThread,
|
|
||||||
M1N1ThreadClusterM10,
|
|
||||||
M1N1ThreadClusterN10,
|
|
||||||
M1N1ThreadClusterM11,
|
|
||||||
M1N1ThreadClusterN11,
|
|
||||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
|
||||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_M1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
|
||||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_N1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
|
||||||
|
|
||||||
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
|
||||||
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
|
||||||
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
|
||||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
|
||||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
if(hipThreadIdx_x == 0)
|
|
||||||
{
|
|
||||||
*static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
|
|
||||||
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
|
|
||||||
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
|
|
||||||
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
|
|
||||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
|
||||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
extern "C" __global__ void
|
|
||||||
#if CK_USE_LAUNCH_BOUNDS
|
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
|
||||||
#endif
|
|
||||||
dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
|
||||||
FloatC* __restrict__ p_c_grid,
|
|
||||||
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
|
||||||
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
|
||||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
|
|
||||||
constexpr auto in_n_c_hi_wi_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
|
||||||
constexpr auto wei_k_c_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
|
|
||||||
constexpr auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
|
||||||
|
|
||||||
constexpr auto descs =
|
|
||||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
|
||||||
in_n_c_hi_wi_desc,
|
|
||||||
out_n_k_ho_wo_desc,
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1));
|
|
||||||
|
|
||||||
constexpr auto a_k_m_grid_desc = descs[I0];
|
|
||||||
constexpr auto b_k_n_grid_desc = descs[I1];
|
|
||||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
|
||||||
|
|
||||||
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
|
||||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
|
||||||
AKMGridDesc,
|
|
||||||
BKNGridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
M1PerThread,
|
|
||||||
N1PerThread,
|
|
||||||
KPerThread,
|
|
||||||
M1N1ThreadClusterM10,
|
|
||||||
M1N1ThreadClusterN10,
|
|
||||||
M1N1ThreadClusterM11,
|
|
||||||
M1N1ThreadClusterN11,
|
|
||||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
|
||||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_M1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
|
||||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_N1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
|
||||||
|
|
||||||
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
|
||||||
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
|
||||||
constexpr auto b_k_n0_n1_grid_desc_tmp =
|
|
||||||
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
|
||||||
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
|
|
||||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
|
||||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
|
||||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
|
|
||||||
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
|
|
||||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
|
|
||||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
|
||||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
|
||||||
|
|
||||||
const auto a_k_m0_m1_grid_desc =
|
|
||||||
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
|
|
||||||
const auto b_k_n0_n1_grid_desc =
|
|
||||||
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
|
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
|
||||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
|
||||||
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
|
||||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
constexpr index_t shared_block_size =
|
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
|
||||||
|
|
||||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
|
||||||
|
|
||||||
GridwiseGemm::Run(p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
p_shared_block,
|
|
||||||
a_k_m0_m1_grid_desc,
|
|
||||||
b_k_n0_n1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor,
|
|
||||||
integral_constant<bool, HasMainKBlockLoop>{},
|
|
||||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
|
||||||
};
|
|
||||||
@@ -1,362 +0,0 @@
|
|||||||
#include "common_header.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
|
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
|
||||||
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
|
||||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
|
||||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
|
||||||
|
|
||||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
|
||||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
|
||||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
|
||||||
|
|
||||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
|
||||||
|
|
||||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
|
||||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
|
||||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
|
||||||
|
|
||||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
|
||||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
|
||||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
|
||||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
|
||||||
constexpr index_t K1 = CK_PARAM_K1;
|
|
||||||
|
|
||||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
|
||||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
|
||||||
using ABlockTransferThreadClusterArrangeOrder =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
|
||||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
|
||||||
|
|
||||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
|
||||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
|
||||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
|
||||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
|
||||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
|
||||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
|
||||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
|
||||||
using BBlockTransferThreadClusterArrangeOrder =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
|
||||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
|
||||||
|
|
||||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
|
||||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
|
||||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
|
||||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
|
||||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
|
||||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
|
||||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
|
||||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
|
||||||
|
|
||||||
extern "C" __global__ void
|
|
||||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
|
|
||||||
int n,
|
|
||||||
int c,
|
|
||||||
int hi,
|
|
||||||
int wi,
|
|
||||||
int k,
|
|
||||||
int y,
|
|
||||||
int x,
|
|
||||||
int convStrideH,
|
|
||||||
int convStrideW,
|
|
||||||
int convDilationY,
|
|
||||||
int convDilationX,
|
|
||||||
int leftPadH,
|
|
||||||
int leftPadW,
|
|
||||||
int rightPadH,
|
|
||||||
int rightPadW,
|
|
||||||
void* p_a_k0_m_k1_grid_desc,
|
|
||||||
void* p_b_k0_n_k1_grid_desc,
|
|
||||||
void* p_c_m0_m1_m2_n_grid_desc,
|
|
||||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
|
|
||||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
|
||||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
|
|
||||||
const auto wei_k_c_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
|
|
||||||
|
|
||||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
|
||||||
wei_k_c_y_x_desc,
|
|
||||||
in_n_c_hi_wi_desc,
|
|
||||||
out_n_k_ho_wo_desc,
|
|
||||||
make_tuple(convStrideH, convStrideW),
|
|
||||||
make_tuple(convDilationY, convDilationX),
|
|
||||||
make_tuple(leftPadH, leftPadW),
|
|
||||||
make_tuple(rightPadH, rightPadW),
|
|
||||||
Number<K1>{});
|
|
||||||
|
|
||||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
|
||||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
|
||||||
const auto c_m_n_grid_desc = descs[I2];
|
|
||||||
|
|
||||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
|
||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
|
||||||
make_tuple(
|
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
InMemoryDataOperationEnum_t::Set,
|
|
||||||
AK0MK1GridDesc,
|
|
||||||
BK0NK1GridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
MPerWave,
|
|
||||||
NPerWave,
|
|
||||||
K1,
|
|
||||||
MRepeat,
|
|
||||||
NRepeat,
|
|
||||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_K1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_K1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
if(hipThreadIdx_x == 0)
|
|
||||||
{
|
|
||||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
|
||||||
a_k0_m_k1_grid_desc;
|
|
||||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
|
||||||
b_k0_n_k1_grid_desc;
|
|
||||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
|
||||||
c_m0_m1_m2_n_grid_desc;
|
|
||||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
|
||||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
extern "C" __global__ void
|
|
||||||
#if CK_USE_LAUNCH_BOUNDS
|
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
|
||||||
#endif
|
|
||||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
|
||||||
FloatC* __restrict__ p_c_grid,
|
|
||||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
|
||||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
|
||||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
|
||||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
|
|
||||||
constexpr auto in_n_c_hi_wi_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
|
||||||
constexpr auto wei_k_c_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
|
|
||||||
constexpr auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
|
||||||
|
|
||||||
constexpr auto descs =
|
|
||||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
|
||||||
in_n_c_hi_wi_desc,
|
|
||||||
out_n_k_ho_wo_desc,
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
Number<K1>{});
|
|
||||||
|
|
||||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
|
||||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
|
||||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
|
||||||
|
|
||||||
using AGridIteratorHacks = decltype(make_tuple(
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
|
||||||
make_tuple(
|
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using BGridIteratorHacks =
|
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
|
||||||
|
|
||||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
|
||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
InMemoryDataOperationEnum_t::Set,
|
|
||||||
AK0MK1GridDesc,
|
|
||||||
BK0NK1GridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
MPerWave,
|
|
||||||
NPerWave,
|
|
||||||
K1,
|
|
||||||
MRepeat,
|
|
||||||
NRepeat,
|
|
||||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_K1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_K1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
|
||||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
|
||||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
|
||||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
|
||||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
|
||||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
|
||||||
|
|
||||||
const auto a_k0_m_k1_grid_desc =
|
|
||||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
|
||||||
const auto b_k0_n_k1_grid_desc =
|
|
||||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
|
||||||
const auto c_m0_m1_m2_n_grid_desc =
|
|
||||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
|
||||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
constexpr index_t shared_block_size =
|
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
|
||||||
|
|
||||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
|
||||||
|
|
||||||
GridwiseGemm::Run(p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
p_shared_block,
|
|
||||||
a_k0_m_k1_grid_desc,
|
|
||||||
b_k0_n_k1_grid_desc,
|
|
||||||
c_m0_m1_m2_n_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
};
|
|
||||||
@@ -1,362 +0,0 @@
|
|||||||
#include "common_header.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
|
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
|
||||||
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
|
||||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
|
||||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
|
||||||
|
|
||||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
|
||||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
|
||||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
|
||||||
|
|
||||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
|
||||||
|
|
||||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
|
||||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
|
||||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
|
||||||
|
|
||||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
|
||||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
|
||||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
|
||||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
|
||||||
constexpr index_t K1 = CK_PARAM_K1;
|
|
||||||
|
|
||||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
|
||||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
|
||||||
using ABlockTransferThreadClusterArrangeOrder =
|
|
||||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
|
||||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
|
||||||
|
|
||||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
|
||||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
|
||||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
|
||||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
|
||||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
|
||||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
|
||||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
|
||||||
using BBlockTransferThreadClusterArrangeOrder =
|
|
||||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
|
||||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
|
||||||
|
|
||||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
|
||||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
|
||||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
|
||||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
|
||||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
|
||||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
|
||||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
|
||||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
|
||||||
|
|
||||||
extern "C" __global__ void
|
|
||||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
|
|
||||||
int n,
|
|
||||||
int hi,
|
|
||||||
int wi,
|
|
||||||
int c,
|
|
||||||
int k,
|
|
||||||
int y,
|
|
||||||
int x,
|
|
||||||
int convStrideH,
|
|
||||||
int convStrideW,
|
|
||||||
int convDilationY,
|
|
||||||
int convDilationX,
|
|
||||||
int leftPadH,
|
|
||||||
int leftPadW,
|
|
||||||
int rightPadH,
|
|
||||||
int rightPadW,
|
|
||||||
void* p_a_k0_m_k1_grid_desc,
|
|
||||||
void* p_b_k0_n_k1_grid_desc,
|
|
||||||
void* p_c_m0_m1_m2_n_grid_desc,
|
|
||||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
|
|
||||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
|
||||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, hi, wi, c));
|
|
||||||
const auto wei_k_y_x_c_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, y, x, c));
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, ho, wo, k));
|
|
||||||
|
|
||||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
|
||||||
in_n_hi_wi_c_desc,
|
|
||||||
wei_k_y_x_c_desc,
|
|
||||||
out_n_ho_wo_k_desc,
|
|
||||||
make_tuple(convStrideH, convStrideW),
|
|
||||||
make_tuple(convDilationY, convDilationX),
|
|
||||||
make_tuple(leftPadH, leftPadW),
|
|
||||||
make_tuple(rightPadH, rightPadW),
|
|
||||||
Number<K1>{});
|
|
||||||
|
|
||||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
|
||||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
|
||||||
const auto c_m_n_grid_desc = descs[I2];
|
|
||||||
|
|
||||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
|
||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
|
||||||
make_tuple(
|
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
InMemoryDataOperationEnum_t::Set,
|
|
||||||
AK0MK1GridDesc,
|
|
||||||
BK0NK1GridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
MPerWave,
|
|
||||||
NPerWave,
|
|
||||||
K1,
|
|
||||||
MRepeat,
|
|
||||||
NRepeat,
|
|
||||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_K1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_K1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
if(hipThreadIdx_x == 0)
|
|
||||||
{
|
|
||||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
|
||||||
a_k0_m_k1_grid_desc;
|
|
||||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
|
||||||
b_k0_n_k1_grid_desc;
|
|
||||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
|
||||||
c_m0_m1_m2_n_grid_desc;
|
|
||||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
|
||||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
extern "C" __global__ void
|
|
||||||
#if CK_USE_LAUNCH_BOUNDS
|
|
||||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
|
||||||
#endif
|
|
||||||
dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
|
||||||
const FloatAB* __restrict__ p_a_grid,
|
|
||||||
const FloatAB* __restrict__ p_b_grid,
|
|
||||||
FloatC* __restrict__ p_c_grid,
|
|
||||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
|
||||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
|
||||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
|
||||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
|
||||||
{
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
constexpr auto in_n_hi_wi_c_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
|
|
||||||
constexpr auto wei_k_y_x_c_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 3, 3, 256));
|
|
||||||
constexpr auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 28, 28, 256));
|
|
||||||
|
|
||||||
constexpr auto descs =
|
|
||||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
|
||||||
wei_k_y_x_c_desc,
|
|
||||||
out_n_ho_wo_k_desc,
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
make_tuple(1, 1),
|
|
||||||
Number<K1>{});
|
|
||||||
|
|
||||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
|
||||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
|
||||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
|
||||||
|
|
||||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
|
||||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
|
||||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using BGridIteratorHacks = decltype(make_tuple(
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
|
||||||
make_tuple(
|
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridIteratorHacks =
|
|
||||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
|
||||||
|
|
||||||
using CGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 1, 0, 0>{}),
|
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
|
||||||
Sequence<0, 0, 2, 0, 0>{})));
|
|
||||||
|
|
||||||
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
|
||||||
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0>;
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
InMemoryDataOperationEnum_t::Set,
|
|
||||||
AK0MK1GridDesc,
|
|
||||||
BK0NK1GridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
MPerWave,
|
|
||||||
NPerWave,
|
|
||||||
K1,
|
|
||||||
MRepeat,
|
|
||||||
NRepeat,
|
|
||||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_K1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_K1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
false>;
|
|
||||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
|
||||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
|
||||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
|
||||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
|
||||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
|
||||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
|
||||||
|
|
||||||
const auto a_k0_m_k1_grid_desc =
|
|
||||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
|
||||||
const auto b_k0_n_k1_grid_desc =
|
|
||||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
|
||||||
const auto c_m0_m1_m2_n_grid_desc =
|
|
||||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
|
||||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
constexpr index_t shared_block_size =
|
|
||||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
|
||||||
|
|
||||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
|
||||||
|
|
||||||
GridwiseGemm::Run(p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
p_shared_block,
|
|
||||||
a_k0_m_k1_grid_desc,
|
|
||||||
b_k0_n_k1_grid_desc,
|
|
||||||
c_m0_m1_m2_n_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
};
|
|
||||||
5670
external/half/include/half.hpp
vendored
5670
external/half/include/half.hpp
vendored
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,2 @@
|
|||||||
add_subdirectory(host_tensor)
|
add_subdirectory(host_tensor)
|
||||||
add_subdirectory(online_compile)
|
|
||||||
add_subdirectory(driver_offline)
|
add_subdirectory(driver_offline)
|
||||||
add_subdirectory(driver_online)
|
|
||||||
|
|||||||
@@ -9,11 +9,10 @@ include_directories(BEFORE
|
|||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
|
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
|
||||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||||
${PROJECT_SOURCE_DIR}/external/half/include
|
|
||||||
)
|
)
|
||||||
|
|
||||||
set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp)
|
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
|
||||||
set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp)
|
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
|
||||||
|
|
||||||
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
|
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
|
||||||
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
|
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
|
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
const InLengths& in_n_hi_wi_c_lengths,
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
const WeiLengths& wei_k_y_x_c_lengths,
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
const OutLengths& out_n_ho_wo_k_lengths,
|
||||||
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
constexpr auto I3 = Number<3>{};
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||||
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||||
const auto wei_k_y_x_c_desc =
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||||
@@ -215,7 +207,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -223,7 +215,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||||
|
|
||||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -231,7 +223,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||||
|
|
||||||
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
|
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
||||||
@@ -251,15 +243,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
float ave_time = driver_gemm_xdlops_v2r3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -295,11 +287,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
||||||
6,
|
6,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false // CAccessOrderMRepeatNRepeat
|
false // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
@@ -307,11 +299,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
in_gemmm_gemmn_grid_desc,
|
in_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
out_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
in_m0_m1_m2_n_grid_iterator_hacks,
|
in_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -319,16 +311,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyx
|
|||||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||||
const auto C = wei_k_y_x_c_lengths[I3];
|
const auto C = wei_k_y_x_c_lengths[I3];
|
||||||
|
|
||||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
|
||||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||||
|
|
||||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||||
const auto X = wei_k_y_x_c_lengths[I2];
|
const auto X = wei_k_y_x_c_lengths[I2];
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
const InLengths& in_n_hi_wi_c_lengths,
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
const WeiLengths& wei_k_y_x_c_lengths,
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
const OutLengths& out_n_ho_wo_k_lengths,
|
||||||
@@ -35,11 +35,6 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
constexpr auto I3 = Number<3>{};
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||||
@@ -49,12 +44,9 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||||
const auto wei_k_y_x_c_desc =
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||||
@@ -187,7 +179,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -195,7 +187,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||||
@@ -203,7 +195,7 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||||
|
|
||||||
constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple(
|
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||||
@@ -223,15 +215,15 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
||||||
|
|
||||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
float ave_time = driver_gemm_xdlops_v2r3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -271,11 +263,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
#endif
|
#endif
|
||||||
7,
|
7,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(in_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
true // CAccessOrderMRepeatNRepeat
|
true // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
@@ -283,11 +275,11 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
in_gemmm_gemmn_grid_desc,
|
in_gemmm_gemmn_grid_desc,
|
||||||
out_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
in_m0_m1_m2_n_grid_iterator_hacks,
|
in_m0_m1_m2_n_grid_step_hacks,
|
||||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -295,16 +287,13 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k
|
|||||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||||
const auto C = wei_k_y_x_c_lengths[I3];
|
const auto C = wei_k_y_x_c_lengths[I3];
|
||||||
|
|
||||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
|
||||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||||
|
|
||||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||||
const auto X = wei_k_y_x_c_lengths[I2];
|
const auto X = wei_k_y_x_c_lengths[I2];
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||||
#include "driver_dynamic_gemm_dlops_v1r2.hpp"
|
#include "driver_gemm_dlops_v1r2.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
const InLengths& in_n_c_hi_wi_lengths,
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
const WeiLengths& wei_k_c_y_x_lengths,
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
const OutLengths& out_n_k_ho_wo_lengths,
|
||||||
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
constexpr auto I0 = Number<0>{};
|
constexpr auto I0 = Number<0>{};
|
||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||||
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||||
const auto wei_k_c_y_x_desc =
|
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// cdata = 64, BlockSize = 256, 128x128x8
|
// cdata = 64, BlockSize = 256, 128x128x8
|
||||||
@@ -98,7 +89,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
in_right_pads);
|
in_right_pads);
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
|
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -108,7 +99,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
|
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||||
@@ -116,7 +107,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -130,10 +121,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||||
@@ -142,7 +133,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_dlops_v1r2<
|
float ave_time = driver_gemm_dlops_v1r2<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -180,26 +171,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||||
5, // CThreadTransferSrcDstVectorDim
|
5, // CThreadTransferSrcDstVectorDim
|
||||||
GemmCThreadTransferDstScalarPerVector_N11,
|
GemmCThreadTransferDstScalarPerVector_N11,
|
||||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
|
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
|
||||||
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
|
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
|
||||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
|
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||||
wei_gemmk_gemmm_grid_desc,
|
wei_gemmk_gemmm_grid_desc,
|
||||||
in_gemmk_gemmn_grid_desc,
|
in_gemmk_gemmn_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
|
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
|
||||||
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
|
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
|
||||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
|
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
float perf = (float)calculate_convolution_flops(
|
float perf = static_cast<float>(calculate_convolution_flops(
|
||||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -13,7 +13,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
const InLengths& in_n_c_hi_wi_lengths,
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
const WeiLengths& wei_k_c_y_x_lengths,
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
const OutLengths& out_n_k_ho_wo_lengths,
|
||||||
@@ -48,12 +48,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
|
|||||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||||
const auto wei_k_c_y_x_desc =
|
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
constexpr index_t BlockSize = 256;
|
constexpr index_t BlockSize = 256;
|
||||||
@@ -212,9 +209,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
|
|||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
#if 0
|
#if 0
|
||||||
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1
|
float ave_time = launch_kernel_gemm_xdlops_v1
|
||||||
#else
|
#else
|
||||||
float ave_time = launch_kernel_dynamic_gemm_xdlops_v2
|
float ave_time = launch_kernel_gemm_xdlops_v2
|
||||||
#endif
|
#endif
|
||||||
<BlockSize,
|
<BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "driver_dynamic_gemm_dlops_v1r3.hpp"
|
#include "driver_gemm_dlops_v1r3.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
const InLengths& in_n_hi_wi_c_lengths,
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
const WeiLengths& wei_k_y_x_c_lengths,
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
const OutLengths& out_n_ho_wo_k_lengths,
|
||||||
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
constexpr auto I3 = Number<3>{};
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||||
@@ -49,14 +44,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||||
const auto wei_k_y_x_c_desc =
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 0
|
||||||
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
|
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
|
||||||
// cdata = 64, BlockSize = 256
|
// cdata = 64, BlockSize = 256
|
||||||
constexpr index_t BlockSize = 256;
|
constexpr index_t BlockSize = 256;
|
||||||
@@ -163,7 +155,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
||||||
@@ -173,7 +165,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
||||||
@@ -183,7 +175,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||||
|
|
||||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
|
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
||||||
@@ -197,15 +189,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
||||||
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_dlops_v1r3<
|
float ave_time = driver_gemm_dlops_v1r3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -239,22 +231,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
||||||
5, // CThreadTransferSrcDstVectorDim
|
5, // CThreadTransferSrcDstVectorDim
|
||||||
GemmCThreadTransferDstScalarPerVector_N11,
|
GemmCThreadTransferDstScalarPerVector_N11,
|
||||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
|
||||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
|
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
|
||||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
|
||||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
|
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -262,16 +254,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
|
|||||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||||
const auto C = wei_k_y_x_c_lengths[I3];
|
const auto C = wei_k_y_x_c_lengths[I3];
|
||||||
|
|
||||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
|
||||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||||
|
|
||||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||||
const auto X = wei_k_y_x_c_lengths[I2];
|
const auto X = wei_k_y_x_c_lengths[I2];
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
const InLengths& in_n_c_hi_wi_lengths,
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
const WeiLengths& wei_k_c_y_x_lengths,
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
const OutLengths& out_n_k_ho_wo_lengths,
|
||||||
@@ -34,12 +34,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|||||||
constexpr auto I0 = Number<0>{};
|
constexpr auto I0 = Number<0>{};
|
||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||||
@@ -49,12 +43,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|||||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||||
const auto wei_k_c_y_x_desc =
|
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||||
@@ -101,12 +92,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -114,7 +105,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -132,15 +123,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
float ave_time = driver_gemm_xdlops_v2r3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -176,26 +167,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|||||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||||
7,
|
7,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
float perf = (float)calculate_convolution_flops(
|
float perf = static_cast<float>(calculate_convolution_flops(
|
||||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
|
#include "driver_gemm_xdlops_v2r2.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
const InLengths& in_n_hi_wi_c_lengths,
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
const WeiLengths& wei_k_y_x_c_lengths,
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
const OutLengths& out_n_ho_wo_k_lengths,
|
||||||
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
constexpr auto I3 = Number<3>{};
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||||
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||||
const auto wei_k_y_x_c_desc =
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||||
@@ -129,12 +121,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -142,7 +134,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -152,15 +144,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_xdlops_v2r2<
|
float ave_time = driver_gemm_xdlops_v2r2<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -195,22 +187,22 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<2, 3, 0, 1>,
|
Sequence<2, 3, 0, 1>,
|
||||||
2,
|
2,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
|
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -218,9 +210,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
|||||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||||
const auto C = wei_k_y_x_c_lengths[I3];
|
const auto C = wei_k_y_x_c_lengths[I3];
|
||||||
|
|
||||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
|
||||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||||
|
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
const InLengths& in_n_hi_wi_c_lengths,
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
const WeiLengths& wei_k_y_x_c_lengths,
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
const OutLengths& out_n_ho_wo_k_lengths,
|
||||||
@@ -49,12 +49,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
|||||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||||
const auto wei_k_y_x_c_desc =
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||||
@@ -185,12 +182,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||||
@@ -198,7 +195,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 1, 0, 0>{},
|
Sequence<0, 0, 1, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -216,15 +213,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 2, 0, 0>{}));
|
Sequence<0, 0, 2, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
float ave_time = driver_gemm_xdlops_v2r3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -259,11 +256,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||||
6,
|
6,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false // CAccessOrderMRepeatNRepeat
|
false // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
@@ -271,11 +268,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
|||||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
|
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -14,7 +14,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
const InLengths& in_n_hi_wi_c_lengths,
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
const WeiLengths& wei_k_y_x_c_lengths,
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
const OutLengths& out_n_ho_wo_k_lengths,
|
||||||
@@ -35,11 +35,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
constexpr auto I2 = Number<2>{};
|
||||||
constexpr auto I3 = Number<3>{};
|
constexpr auto I3 = Number<3>{};
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
constexpr auto I7 = Number<7>{};
|
|
||||||
constexpr auto I8 = Number<8>{};
|
|
||||||
|
|
||||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||||
@@ -49,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||||
const auto wei_k_y_x_c_desc =
|
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||||
@@ -241,7 +233,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks =
|
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
||||||
@@ -249,7 +241,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||||
@@ -257,7 +249,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||||
|
|
||||||
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
|
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||||
@@ -275,15 +267,15 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
||||||
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
||||||
|
|
||||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
|
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
|
float ave_time = driver_gemm_xdlops_v2r3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -319,11 +311,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||||
7,
|
7,
|
||||||
GemmCThreadTransferDstScalarPerVector,
|
GemmCThreadTransferDstScalarPerVector,
|
||||||
decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
|
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||||
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
|
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
|
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||||
false // CAccessOrderMRepeatNRepeat
|
false // CAccessOrderMRepeatNRepeat
|
||||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||||
@@ -331,11 +323,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||||
out_gemmm_gemmn_grid_desc,
|
out_gemmm_gemmn_grid_desc,
|
||||||
in_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
|
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||||
out_m0_m1_m2_n_grid_iterator_hacks,
|
out_m0_m1_m2_n_grid_step_hacks,
|
||||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
|
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
|
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -343,16 +335,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
|||||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||||
const auto C = wei_k_y_x_c_lengths[I3];
|
const auto C = wei_k_y_x_c_lengths[I3];
|
||||||
|
|
||||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
|
||||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||||
|
|
||||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||||
const auto X = wei_k_y_x_c_lengths[I2];
|
const auto X = wei_k_y_x_c_lengths[I2];
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
|
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
ck::index_t InWeiVectorSize,
|
ck::index_t InWeiVectorSize,
|
||||||
@@ -15,7 +15,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
const InLengths& in_n_c_hi_wi_lengths,
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
const WeiLengths& wei_k_c_y_x_lengths,
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
const OutLengths& out_n_k_ho_wo_lengths,
|
||||||
@@ -26,7 +26,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
|||||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||||
const Tensor<TInWei>& wei_k_c_y_x,
|
const Tensor<TInWei>& wei_k_c_y_x,
|
||||||
Tensor<TOut>& out_n_k_ho_wo,
|
Tensor<TOut>& out_n_k_ho_wo,
|
||||||
ck::index_t nrepeat)
|
ck::index_t /* nrepeat */)
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
|
|
||||||
@@ -85,12 +85,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
|||||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||||
|
|
||||||
const auto in_n_c0_hi_wi_desc =
|
const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
|
const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
|
||||||
const auto wei_k_c0_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
|
|
||||||
const auto out_n_k0_ho_wo_k1_desc =
|
const auto out_n_k0_ho_wo_k1_desc =
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
|
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// cdata = 64, BlockSize = 64, 16x8x32x4
|
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
#include "device.hpp"
|
#include "device.hpp"
|
||||||
#include "host_tensor.hpp"
|
#include "host_tensor.hpp"
|
||||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||||
#include "driver_dynamic_contraction_dlops_v1r2.hpp"
|
#include "driver_contraction_dlops_v1r2.hpp"
|
||||||
|
|
||||||
template <typename TInWei,
|
template <typename TInWei,
|
||||||
typename TAcc,
|
typename TAcc,
|
||||||
@@ -15,7 +15,7 @@ template <typename TInWei,
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
const InLengths& in_n_c_hi_wi_lengths,
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
const WeiLengths& wei_k_c_y_x_lengths,
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
const OutLengths& out_n_k_ho_wo_lengths,
|
||||||
@@ -44,12 +44,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||||
|
|
||||||
const auto in_desc_n_c_hi_wi =
|
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||||
const auto wei_desc_k_c_y_x =
|
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
|
||||||
const auto out_desc_n_k_ho_wo =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
|
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
|
||||||
@@ -133,7 +130,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||||
|
|
||||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||||
constexpr auto wei_grid_iterator_hacks =
|
constexpr auto wei_grid_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||||
@@ -145,7 +142,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||||
|
|
||||||
constexpr auto in_grid_iterator_hacks = make_tuple(
|
constexpr auto in_grid_step_hacks = make_tuple(
|
||||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||||
@@ -157,7 +154,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||||
|
|
||||||
constexpr auto out_grid_iterator_hacks = make_tuple(
|
constexpr auto out_grid_step_hacks = make_tuple(
|
||||||
make_tuple(
|
make_tuple(
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||||
@@ -173,14 +170,14 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
||||||
|
|
||||||
constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto in_grid_move_slice_window_iterator_hacks =
|
constexpr auto in_grid_move_slice_window_step_hacks =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
||||||
|
|
||||||
for(index_t i = 0; i < 5; ++i)
|
for(index_t i = 0; i < 5; ++i)
|
||||||
{
|
{
|
||||||
float ave_time = driver_dynamic_contraction_dlops_v1r2<
|
float ave_time = driver_contraction_dlops_v1r2<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
TInWei,
|
TInWei,
|
||||||
TAcc,
|
TAcc,
|
||||||
@@ -214,26 +211,26 @@ void device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|||||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||||
5, // CThreadTransferSrcDstVectorDim
|
5, // CThreadTransferSrcDstVectorDim
|
||||||
CThreadTransferDstScalarPerVector_BN1,
|
CThreadTransferDstScalarPerVector_BN1,
|
||||||
decltype(wei_grid_iterator_hacks),
|
decltype(wei_grid_step_hacks),
|
||||||
decltype(in_grid_iterator_hacks),
|
decltype(in_grid_step_hacks),
|
||||||
decltype(out_grid_iterator_hacks),
|
decltype(out_grid_step_hacks),
|
||||||
decltype(wei_grid_move_slice_window_iterator_hacks),
|
decltype(wei_grid_move_slice_window_step_hacks),
|
||||||
decltype(in_grid_move_slice_window_iterator_hacks)>(
|
decltype(in_grid_move_slice_window_step_hacks)>(
|
||||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||||
wei_grid_desc_gk0_gm0_gm1_gk1,
|
wei_grid_desc_gk0_gm0_gm1_gk1,
|
||||||
in_grid_desc_gk0_gn0_gn1_gk1,
|
in_grid_desc_gk0_gn0_gn1_gk1,
|
||||||
out_grid_desc_gm0_gm1_gn0_gn1,
|
out_grid_desc_gm0_gm1_gn0_gn1,
|
||||||
wei_grid_iterator_hacks,
|
wei_grid_step_hacks,
|
||||||
in_grid_iterator_hacks,
|
in_grid_step_hacks,
|
||||||
out_grid_iterator_hacks,
|
out_grid_step_hacks,
|
||||||
wei_grid_move_slice_window_iterator_hacks,
|
wei_grid_move_slice_window_step_hacks,
|
||||||
in_grid_move_slice_window_iterator_hacks,
|
in_grid_move_slice_window_step_hacks,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
|
|
||||||
float perf = (float)calculate_convolution_flops(
|
float perf = static_cast<float>(calculate_convolution_flops(
|
||||||
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) /
|
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
#ifndef DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
|
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||||
#define DRIVER_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
|
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
|
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||||
|
|
||||||
template <ck::index_t BlockSize,
|
template <ck::index_t BlockSize,
|
||||||
typename FloatAB,
|
typename FloatAB,
|
||||||
@@ -39,24 +39,24 @@ template <ck::index_t BlockSize,
|
|||||||
typename CThreadTransferSrcDstAccessOrder,
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
typename AGridIteratorHacks,
|
typename AGridStepHacks,
|
||||||
typename BGridIteratorHacks,
|
typename BGridStepHacks,
|
||||||
typename CGridIteratorHacks,
|
typename CGridStepHacks,
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
__host__ float
|
__host__ float
|
||||||
driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||||
const FloatAB* p_b_grid,
|
const FloatAB* p_b_grid,
|
||||||
FloatC* p_c_grid,
|
FloatC* p_c_grid,
|
||||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
BGridMoveSliceWindowStepHacks,
|
||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
@@ -70,7 +70,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
|
|
||||||
// GEMM
|
// GEMM
|
||||||
using GridwiseContraction =
|
using GridwiseContraction =
|
||||||
GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
@@ -104,11 +104,11 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
CThreadTransferSrcDstAccessOrder,
|
CThreadTransferSrcDstAccessOrder,
|
||||||
CThreadTransferSrcDstVectorDim,
|
CThreadTransferSrcDstVectorDim,
|
||||||
CThreadTransferDstScalarPerVector,
|
CThreadTransferDstScalarPerVector,
|
||||||
AGridIteratorHacks,
|
AGridStepHacks,
|
||||||
BGridIteratorHacks,
|
BGridStepHacks,
|
||||||
CGridIteratorHacks,
|
CGridStepHacks,
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
AGridMoveSliceWindowStepHacks,
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
|
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
|
||||||
{
|
{
|
||||||
throw std::runtime_error("wrong! "
|
throw std::runtime_error("wrong! "
|
||||||
"GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
|
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
|
||||||
"GM0_GM1_GN0_GN1 has invalid setting");
|
"GM0_GM1_GN0_GN1 has invalid setting");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,7 +178,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
|
|
||||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
{
|
{
|
||||||
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
|
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||||
GridwiseContraction,
|
GridwiseContraction,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatC,
|
FloatC,
|
||||||
@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
dim3(grid_size),
|
dim3(grid_size),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
p_a_grid,
|
p_a_grid,
|
||||||
p_b_grid,
|
p_b_grid,
|
||||||
p_c_grid,
|
p_c_grid,
|
||||||
@@ -205,7 +204,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
}
|
}
|
||||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||||
{
|
{
|
||||||
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
|
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||||
GridwiseContraction,
|
GridwiseContraction,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatC,
|
FloatC,
|
||||||
@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
dim3(grid_size),
|
dim3(grid_size),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
p_a_grid,
|
p_a_grid,
|
||||||
p_b_grid,
|
p_b_grid,
|
||||||
p_c_grid,
|
p_c_grid,
|
||||||
@@ -232,7 +230,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
}
|
}
|
||||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
{
|
{
|
||||||
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
|
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||||
GridwiseContraction,
|
GridwiseContraction,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatC,
|
FloatC,
|
||||||
@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
dim3(grid_size),
|
dim3(grid_size),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
p_a_grid,
|
p_a_grid,
|
||||||
p_b_grid,
|
p_b_grid,
|
||||||
p_c_grid,
|
p_c_grid,
|
||||||
@@ -259,7 +256,7 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const auto kernel = kernel_dynamic_contraction_dlops_v1r2<
|
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||||
GridwiseContraction,
|
GridwiseContraction,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatC,
|
FloatC,
|
||||||
@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
|||||||
dim3(grid_size),
|
dim3(grid_size),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
p_a_grid,
|
p_a_grid,
|
||||||
p_b_grid,
|
p_b_grid,
|
||||||
p_c_grid,
|
p_c_grid,
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||||
#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "gridwise_dynamic_gemm_dlops_v2.hpp"
|
#include "gridwise_gemm_dlops_v2.hpp"
|
||||||
#include "gridwise_operation_wrapper.hpp"
|
#include "gridwise_operation_wrapper.hpp"
|
||||||
|
|
||||||
template <ck::index_t BlockSize,
|
template <ck::index_t BlockSize,
|
||||||
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||||
const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||||
const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -82,14 +82,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
const auto InRightPadW = in_right_pads[I1];
|
const auto InRightPadW = in_right_pads[I1];
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_global_desc,
|
in_n_c_hi_wi_global_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -98,7 +98,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hip_wip_global_desc,
|
in_n_c_hip_wip_global_desc,
|
||||||
make_tuple(
|
make_tuple(
|
||||||
make_pass_through_transform(N),
|
make_pass_through_transform(N),
|
||||||
@@ -108,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||||
|
|
||||||
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_y_ho_x_wo_global_desc,
|
in_n_c_y_ho_x_wo_global_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||||
make_pass_through_transform(N),
|
make_pass_through_transform(N),
|
||||||
@@ -118,8 +118,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
|
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||||
make_pass_through_transform(N),
|
make_pass_through_transform(N),
|
||||||
make_pass_through_transform(Ho),
|
make_pass_through_transform(Ho),
|
||||||
@@ -136,13 +136,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||||
constexpr auto a_e_k_global_iterator_hacks =
|
constexpr auto a_e_k_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
|
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
@@ -152,12 +152,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||||
|
|
||||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||||
// hack for NKHW format
|
// hack for NKHW format
|
||||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
|
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -169,7 +169,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
// GEMM
|
// GEMM
|
||||||
using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3<
|
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
@@ -202,11 +202,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
Sequence<0, 2, 3, 1>,
|
Sequence<0, 2, 3, 1>,
|
||||||
0,
|
0,
|
||||||
CThreadTransferDstScalarPerVector_W,
|
CThreadTransferDstScalarPerVector_W,
|
||||||
decltype(a_e_k_global_iterator_hacks),
|
decltype(a_e_k_global_step_hacks),
|
||||||
decltype(b_e_n_ho_wo_global_iterator_hacks),
|
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||||
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
|
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||||
decltype(a_e_k_global_move_slice_window_iterator_hack),
|
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||||
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
|
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||||
|
|
||||||
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
||||||
|
|
||||||
@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -338,10 +334,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
|||||||
|
|
||||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||||
|
|
||||||
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
float perf =
|
||||||
wei_k_c_y_x_global_desc,
|
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||||
out_n_k0_ho_wo_k1_global_desc) /
|
wei_k_c_y_x_global_desc,
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
out_n_k0_ho_wo_k1_global_desc)) /
|
||||||
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
#ifndef DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||||
#define DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||||
|
|
||||||
#include "common_header.hpp"
|
#include "common_header.hpp"
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
#include "tensor_descriptor_helper.hpp"
|
||||||
#include "gridwise_dynamic_gemm_dlops_v2.hpp"
|
#include "gridwise_gemm_dlops_v2.hpp"
|
||||||
#include "gridwise_operation_wrapper.hpp"
|
#include "gridwise_operation_wrapper.hpp"
|
||||||
|
|
||||||
template <ck::index_t BlockSize,
|
template <ck::index_t BlockSize,
|
||||||
@@ -34,9 +34,9 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
typename ConvDilations,
|
typename ConvDilations,
|
||||||
typename InLeftPads,
|
typename InLeftPads,
|
||||||
typename InRightPads>
|
typename InRightPads>
|
||||||
__host__ void Run(const ck::DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||||
const ck::DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||||
const ck::DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
@@ -93,14 +93,14 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
|
||||||
// weight tensor
|
// weight tensor
|
||||||
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
|
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||||
|
|
||||||
// input tensor
|
// input tensor
|
||||||
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hi_wi_global_desc,
|
in_n_c_hi_wi_global_desc,
|
||||||
make_tuple(make_pass_through_transform(N),
|
make_tuple(make_pass_through_transform(N),
|
||||||
make_pass_through_transform(C),
|
make_pass_through_transform(C),
|
||||||
@@ -109,7 +109,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_hip_wip_global_desc,
|
in_n_c_hip_wip_global_desc,
|
||||||
make_tuple(
|
make_tuple(
|
||||||
make_pass_through_transform(N),
|
make_pass_through_transform(N),
|
||||||
@@ -119,7 +119,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||||
|
|
||||||
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||||
in_n_c_y_ho_x_wo_global_desc,
|
in_n_c_y_ho_x_wo_global_desc,
|
||||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||||
make_pass_through_transform(N),
|
make_pass_through_transform(N),
|
||||||
@@ -129,8 +129,8 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||||
|
|
||||||
// output tensor
|
// output tensor
|
||||||
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
|
const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
|
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||||
make_pass_through_transform(N),
|
make_pass_through_transform(N),
|
||||||
make_pad_transform(Ho, 0, OutRightPadH),
|
make_pad_transform(Ho, 0, OutRightPadH),
|
||||||
@@ -149,13 +149,13 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
}
|
}
|
||||||
|
|
||||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||||
constexpr auto a_e_k_global_iterator_hacks =
|
constexpr auto a_e_k_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
|
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
@@ -165,12 +165,12 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||||
|
|
||||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||||
// hack for NKHW format
|
// hack for NKHW format
|
||||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
|
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
Sequence<0, 0, 0, 0, 0>{},
|
Sequence<0, 0, 0, 0, 0>{},
|
||||||
@@ -181,7 +181,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
Sequence<0, 0, 0, 0, 0>{}));
|
Sequence<0, 0, 0, 0, 0>{}));
|
||||||
|
|
||||||
// GEMM
|
// GEMM
|
||||||
using gridwise_gemm = GridwiseDynamicGemmDlops_km_kn_mn_v3<
|
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||||
BlockSize,
|
BlockSize,
|
||||||
FloatAB,
|
FloatAB,
|
||||||
FloatAcc,
|
FloatAcc,
|
||||||
@@ -214,11 +214,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
Sequence<0, 2, 3, 1>,
|
Sequence<0, 2, 3, 1>,
|
||||||
0,
|
0,
|
||||||
CThreadTransferDstScalarPerVector_W,
|
CThreadTransferDstScalarPerVector_W,
|
||||||
decltype(a_e_k_global_iterator_hacks),
|
decltype(a_e_k_global_step_hacks),
|
||||||
decltype(b_e_n_ho_wo_global_iterator_hacks),
|
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||||
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
|
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||||
decltype(a_e_k_global_move_slice_window_iterator_hack),
|
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||||
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
|
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||||
|
|
||||||
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||||
|
|
||||||
@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
dim3(GridSize),
|
dim3(GridSize),
|
||||||
dim3(BlockSize),
|
dim3(BlockSize),
|
||||||
0,
|
0,
|
||||||
0,
|
|
||||||
wei_e_k_global_desc,
|
wei_e_k_global_desc,
|
||||||
p_wei_global,
|
p_wei_global,
|
||||||
in_e_n_ho_wo_global_desc,
|
in_e_n_ho_wo_global_desc,
|
||||||
@@ -354,10 +350,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
|
|||||||
|
|
||||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||||
|
|
||||||
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
float perf =
|
||||||
wei_k_c_y_x_global_desc,
|
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||||
out_n_k0_ho_wo_k1_global_desc) /
|
wei_k_c_y_x_global_desc,
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
out_n_k0_ho_wo_k1_global_desc)) /
|
||||||
|
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
@@ -1,415 +0,0 @@
|
|||||||
#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R2
|
|
||||||
#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R2
|
|
||||||
|
|
||||||
#include "common_header.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
|
|
||||||
|
|
||||||
template <ck::index_t BlockSize,
|
|
||||||
typename FloatAB,
|
|
||||||
typename FloatAcc,
|
|
||||||
typename FloatC,
|
|
||||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
|
||||||
typename AKMGridDesc,
|
|
||||||
typename BKNGridDesc,
|
|
||||||
typename CMNGridDesc,
|
|
||||||
ck::index_t MPerBlock,
|
|
||||||
ck::index_t NPerBlock,
|
|
||||||
ck::index_t KPerBlock,
|
|
||||||
ck::index_t M1PerThread,
|
|
||||||
ck::index_t N1PerThread,
|
|
||||||
ck::index_t KPerThread,
|
|
||||||
ck::index_t M1N1ThreadClusterM10,
|
|
||||||
ck::index_t M1N1ThreadClusterN10,
|
|
||||||
ck::index_t M1N1ThreadClusterM11,
|
|
||||||
ck::index_t M1N1ThreadClusterN11,
|
|
||||||
typename ABlockTransferThreadSliceLengths_K_M0_M1,
|
|
||||||
typename ABlockTransferThreadClusterLengths_K_M0_M1,
|
|
||||||
typename ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
typename ABlockTransferSrcAccessOrder,
|
|
||||||
ck::index_t ABlockTransferSrcVectorDim,
|
|
||||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
|
||||||
ck::index_t ABlockTransferDstScalarPerVector_M1,
|
|
||||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
typename BBlockTransferThreadSliceLengths_K_N0_N1,
|
|
||||||
typename BBlockTransferThreadClusterLengths_K_N0_N1,
|
|
||||||
typename BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
typename BBlockTransferSrcAccessOrder,
|
|
||||||
ck::index_t BBlockTransferSrcVectorDim,
|
|
||||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
|
||||||
ck::index_t BBlockTransferDstScalarPerVector_N1,
|
|
||||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
typename CThreadTransferSrcDstAccessOrder,
|
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
|
||||||
typename AGridIteratorHacks,
|
|
||||||
typename BGridIteratorHacks,
|
|
||||||
typename CGridIteratorHacks,
|
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
|
||||||
__host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
|
||||||
const FloatAB* p_b_grid,
|
|
||||||
FloatC* p_c_grid,
|
|
||||||
const AKMGridDesc& a_k_m_grid_desc,
|
|
||||||
const BKNGridDesc& b_k_n_grid_desc,
|
|
||||||
const CMNGridDesc& c_m_n_grid_desc,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
|
|
||||||
// GEMM
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
CGlobalMemoryDataOperation,
|
|
||||||
AKMGridDesc,
|
|
||||||
BKNGridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
M1PerThread,
|
|
||||||
N1PerThread,
|
|
||||||
KPerThread,
|
|
||||||
M1N1ThreadClusterM10,
|
|
||||||
M1N1ThreadClusterN10,
|
|
||||||
M1N1ThreadClusterM11,
|
|
||||||
M1N1ThreadClusterN11,
|
|
||||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
|
||||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_M1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
|
||||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_N1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
|
||||||
|
|
||||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
|
||||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
|
||||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
|
||||||
|
|
||||||
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
|
|
||||||
{
|
|
||||||
throw std::runtime_error(
|
|
||||||
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r2 has invalid setting");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
|
||||||
const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
|
||||||
|
|
||||||
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
|
|
||||||
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);
|
|
||||||
|
|
||||||
// c_m0_m10_m11_n0_n10_n11_grid_desc
|
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
|
||||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
|
|
||||||
|
|
||||||
// c_blockid_to_m0_n0_block_cluster_adaptor
|
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
|
|
||||||
|
|
||||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
|
||||||
|
|
||||||
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);
|
|
||||||
|
|
||||||
{
|
|
||||||
std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
|
|
||||||
<< "}" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
|
|
||||||
<< "}" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
|
||||||
float ave_time = 0;
|
|
||||||
|
|
||||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k_m0_m1_grid_desc,
|
|
||||||
b_k_n0_n1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k_m0_m1_grid_desc,
|
|
||||||
b_k_n0_n1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k_m0_m1_grid_desc,
|
|
||||||
b_k_n0_n1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k_m0_m1_grid_desc,
|
|
||||||
b_k_n0_n1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ave_time;
|
|
||||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
|
||||||
DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
|
|
||||||
DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
|
|
||||||
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
|
|
||||||
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
|
|
||||||
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
|
|
||||||
|
|
||||||
a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
|
|
||||||
b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
|
|
||||||
&c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
float ave_time = 0;
|
|
||||||
|
|
||||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r2<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AKM0M1GridDesc>,
|
|
||||||
remove_reference_t<BKN0N1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
|
|
||||||
return ave_time;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,411 +0,0 @@
|
|||||||
#ifndef DRIVER_DYNAMIC_GEMM_DLOPS_V1R3
|
|
||||||
#define DRIVER_DYNAMIC_GEMM_DLOPS_V1R3
|
|
||||||
|
|
||||||
#include "common_header.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "gridwise_dynamic_gemm_dlops_v1r3.hpp"
|
|
||||||
|
|
||||||
template <ck::index_t BlockSize,
|
|
||||||
typename FloatAB,
|
|
||||||
typename FloatAcc,
|
|
||||||
typename FloatC,
|
|
||||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
|
||||||
typename AK0MK1GridDesc,
|
|
||||||
typename BK0NK1GridDesc,
|
|
||||||
typename CMNGridDesc,
|
|
||||||
ck::index_t MPerBlock,
|
|
||||||
ck::index_t NPerBlock,
|
|
||||||
ck::index_t KPerBlock,
|
|
||||||
ck::index_t M1PerThread,
|
|
||||||
ck::index_t N1PerThread,
|
|
||||||
ck::index_t KPerThread,
|
|
||||||
typename M1N1ThreadClusterM1Xs,
|
|
||||||
typename M1N1ThreadClusterN1Xs,
|
|
||||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
|
||||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
|
||||||
typename ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
typename ABlockTransferSrcAccessOrder,
|
|
||||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
|
||||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
|
||||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
|
||||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
|
||||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
|
||||||
typename BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
typename BBlockTransferSrcAccessOrder,
|
|
||||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
|
||||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
|
||||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
|
||||||
typename CThreadTransferSrcDstAccessOrder,
|
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
|
||||||
typename AGridIteratorHacks,
|
|
||||||
typename BGridIteratorHacks,
|
|
||||||
typename CGridIteratorHacks,
|
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
|
||||||
typename BGridMoveSliceWindowIteratorHacks>
|
|
||||||
__host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
|
||||||
const FloatAB* p_b_grid,
|
|
||||||
FloatC* p_c_grid,
|
|
||||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
|
||||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
|
||||||
const CMNGridDesc& c_m_n_grid_desc,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
|
|
||||||
// GEMM
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemmDlops_km_kn_mn_v1r3<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
CGlobalMemoryDataOperation,
|
|
||||||
AK0MK1GridDesc,
|
|
||||||
BK0NK1GridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
M1PerThread,
|
|
||||||
N1PerThread,
|
|
||||||
KPerThread,
|
|
||||||
M1N1ThreadClusterM1Xs,
|
|
||||||
M1N1ThreadClusterN1Xs,
|
|
||||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
|
||||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
|
||||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
|
||||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
|
||||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
|
||||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
|
||||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
|
||||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks>;
|
|
||||||
|
|
||||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
|
||||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
|
||||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
|
||||||
|
|
||||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
|
||||||
{
|
|
||||||
throw std::runtime_error(
|
|
||||||
"wrong! GridwiseDynamicGemmDlops_km_kn_mn_v1r3 has invalid setting");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto a_k0_m0_m1_k1_grid_desc =
|
|
||||||
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
|
|
||||||
const auto b_k0_n0_n1_k1_grid_desc =
|
|
||||||
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);
|
|
||||||
|
|
||||||
using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
|
|
||||||
using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);
|
|
||||||
|
|
||||||
// c_m0_m10_m11_n0_n10_n11_grid_desc
|
|
||||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
|
||||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
|
|
||||||
|
|
||||||
// c_blockid_to_m0_n0_block_cluster_adaptor
|
|
||||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
|
||||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
|
|
||||||
|
|
||||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
|
||||||
|
|
||||||
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
|
|
||||||
|
|
||||||
{
|
|
||||||
std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
|
|
||||||
<< a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
|
|
||||||
<< a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
|
|
||||||
<< b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
|
|
||||||
<< b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
|
|
||||||
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
|
||||||
float ave_time = 0;
|
|
||||||
|
|
||||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k0_m0_m1_k1_grid_desc,
|
|
||||||
b_k0_n0_n1_k1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k0_m0_m1_k1_grid_desc,
|
|
||||||
b_k0_n0_n1_k1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k0_m0_m1_k1_grid_desc,
|
|
||||||
b_k0_n0_n1_k1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k0_m0_m1_k1_grid_desc,
|
|
||||||
b_k0_n0_n1_k1_grid_desc,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
}
|
|
||||||
|
|
||||||
return ave_time;
|
|
||||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
|
||||||
DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
|
|
||||||
DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
|
|
||||||
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
|
|
||||||
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
|
|
||||||
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
|
|
||||||
|
|
||||||
a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
|
|
||||||
b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
|
|
||||||
&c_blockid_to_m0_n0_block_cluster_adaptor);
|
|
||||||
|
|
||||||
float ave_time = 0;
|
|
||||||
|
|
||||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
true,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
true>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const auto kernel =
|
|
||||||
kernel_dynamic_gemm_dlops_v1r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0M0M1K1GridDesc>,
|
|
||||||
remove_reference_t<BK0N0N1K1GridDesc>,
|
|
||||||
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
|
||||||
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
|
||||||
false,
|
|
||||||
false>;
|
|
||||||
|
|
||||||
ave_time = launch_and_time_kernel(
|
|
||||||
kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
}
|
|
||||||
|
|
||||||
return ave_time;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
#ifndef DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
|
|
||||||
#define DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
|
|
||||||
|
|
||||||
#include "common_header.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
|
|
||||||
|
|
||||||
template <ck::index_t BlockSize,
|
|
||||||
typename FloatAB,
|
|
||||||
typename FloatAcc,
|
|
||||||
typename FloatC,
|
|
||||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
|
||||||
typename AK0MK1GridDesc,
|
|
||||||
typename BK0NK1GridDesc,
|
|
||||||
typename CMNGridDesc,
|
|
||||||
ck::index_t MPerBlock,
|
|
||||||
ck::index_t NPerBlock,
|
|
||||||
ck::index_t KPerBlock,
|
|
||||||
ck::index_t MPerWave,
|
|
||||||
ck::index_t NPerWave,
|
|
||||||
ck::index_t K1,
|
|
||||||
ck::index_t MRepeat,
|
|
||||||
ck::index_t NRepeat,
|
|
||||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
|
||||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
|
||||||
typename ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
typename ABlockTransferSrcAccessOrder,
|
|
||||||
ck::index_t ABlockTransferSrcVectorDim,
|
|
||||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
|
||||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
|
||||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
|
||||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
|
||||||
typename BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
typename BBlockTransferSrcAccessOrder,
|
|
||||||
ck::index_t BBlockTransferSrcVectorDim,
|
|
||||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
|
||||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
|
||||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
typename CThreadTransferSrcDstAccessOrder,
|
|
||||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
|
||||||
ck::index_t CThreadTransferDstScalarPerVector,
|
|
||||||
typename AGridIteratorHacks,
|
|
||||||
typename BGridIteratorHacks,
|
|
||||||
typename CGridIteratorHacks,
|
|
||||||
typename AGridMoveSliceWindowIteratorHacks,
|
|
||||||
typename BGridMoveSliceWindowIteratorHacks,
|
|
||||||
bool CAccessOrderMRepeatNRepeat>
|
|
||||||
__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
|
||||||
const FloatAB* p_b_grid,
|
|
||||||
FloatC* p_c_grid,
|
|
||||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
|
||||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
|
||||||
const CMNGridDesc& c_m_n_grid_desc,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
|
|
||||||
using GridwiseGemm =
|
|
||||||
GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
|
||||||
FloatAB,
|
|
||||||
FloatAcc,
|
|
||||||
FloatC,
|
|
||||||
CGlobalMemoryDataOperation,
|
|
||||||
AK0MK1GridDesc,
|
|
||||||
BK0NK1GridDesc,
|
|
||||||
CMNGridDesc,
|
|
||||||
MPerBlock,
|
|
||||||
NPerBlock,
|
|
||||||
KPerBlock,
|
|
||||||
MPerWave,
|
|
||||||
NPerWave,
|
|
||||||
K1,
|
|
||||||
MRepeat,
|
|
||||||
NRepeat,
|
|
||||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
|
||||||
ABlockTransferThreadClusterArrangeOrder,
|
|
||||||
ABlockTransferSrcAccessOrder,
|
|
||||||
ABlockTransferSrcVectorDim,
|
|
||||||
ABlockTransferSrcScalarPerVector,
|
|
||||||
ABlockTransferDstScalarPerVector_K1,
|
|
||||||
AThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
|
||||||
BBlockTransferThreadClusterArrangeOrder,
|
|
||||||
BBlockTransferSrcAccessOrder,
|
|
||||||
BBlockTransferSrcVectorDim,
|
|
||||||
BBlockTransferSrcScalarPerVector,
|
|
||||||
BBlockTransferDstScalarPerVector_K1,
|
|
||||||
BThreadTransferSrcResetCoordinateAfterRun,
|
|
||||||
CThreadTransferSrcDstAccessOrder,
|
|
||||||
CThreadTransferSrcDstVectorDim,
|
|
||||||
CThreadTransferDstScalarPerVector,
|
|
||||||
AGridIteratorHacks,
|
|
||||||
BGridIteratorHacks,
|
|
||||||
CGridIteratorHacks,
|
|
||||||
AGridMoveSliceWindowIteratorHacks,
|
|
||||||
BGridMoveSliceWindowIteratorHacks,
|
|
||||||
CAccessOrderMRepeatNRepeat>;
|
|
||||||
|
|
||||||
{
|
|
||||||
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
|
|
||||||
<< "}" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
|
|
||||||
<< "}" << std::endl;
|
|
||||||
|
|
||||||
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
|
||||||
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
|
||||||
{
|
|
||||||
throw std::runtime_error(
|
|
||||||
"wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
|
|
||||||
|
|
||||||
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
|
||||||
|
|
||||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
|
||||||
|
|
||||||
const auto kernel = kernel_dynamic_gemm_xdlops_v2r3<GridwiseGemm,
|
|
||||||
FloatAB,
|
|
||||||
FloatC,
|
|
||||||
remove_reference_t<AK0MK1GridDesc>,
|
|
||||||
remove_reference_t<BK0NK1GridDesc>,
|
|
||||||
remove_reference_t<CM0M1M2NGridDesc>,
|
|
||||||
remove_reference_t<CBlockClusterAdaptor>>;
|
|
||||||
|
|
||||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
|
||||||
float ave_time = launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
a_k0_m_k1_grid_desc,
|
|
||||||
b_k0_n_k1_grid_desc,
|
|
||||||
c_m0_m1_m2_n_grid_desc,
|
|
||||||
c_block_cluster_adaptor);
|
|
||||||
|
|
||||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
|
||||||
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
|
||||||
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
|
||||||
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
|
|
||||||
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
|
||||||
|
|
||||||
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
|
|
||||||
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
|
|
||||||
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
|
|
||||||
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
|
||||||
|
|
||||||
float ave_time =
|
|
||||||
launch_and_time_kernel(kernel,
|
|
||||||
nrepeat,
|
|
||||||
dim3(grid_size),
|
|
||||||
dim3(BlockSize),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
p_a_grid,
|
|
||||||
p_b_grid,
|
|
||||||
p_c_grid,
|
|
||||||
(void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
|
|
||||||
(void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
|
|
||||||
#endif
|
|
||||||
return ave_time;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
413
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
Normal file
413
host/driver_offline/include/driver_gemm_dlops_v1r2.hpp
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
#ifndef DRIVER_GEMM_DLOPS_V1R2
|
||||||
|
#define DRIVER_GEMM_DLOPS_V1R2
|
||||||
|
|
||||||
|
#include "common_header.hpp"
|
||||||
|
#include "tensor_descriptor.hpp"
|
||||||
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
#include "gridwise_gemm_dlops_v1r2.hpp"
|
||||||
|
|
||||||
|
template <ck::index_t BlockSize,
|
||||||
|
typename FloatAB,
|
||||||
|
typename FloatAcc,
|
||||||
|
typename FloatC,
|
||||||
|
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||||
|
typename AKMGridDesc,
|
||||||
|
typename BKNGridDesc,
|
||||||
|
typename CMNGridDesc,
|
||||||
|
ck::index_t MPerBlock,
|
||||||
|
ck::index_t NPerBlock,
|
||||||
|
ck::index_t KPerBlock,
|
||||||
|
ck::index_t M1PerThread,
|
||||||
|
ck::index_t N1PerThread,
|
||||||
|
ck::index_t KPerThread,
|
||||||
|
ck::index_t M1N1ThreadClusterM10,
|
||||||
|
ck::index_t M1N1ThreadClusterN10,
|
||||||
|
ck::index_t M1N1ThreadClusterM11,
|
||||||
|
ck::index_t M1N1ThreadClusterN11,
|
||||||
|
typename ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||||
|
typename ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||||
|
typename ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
typename ABlockTransferSrcAccessOrder,
|
||||||
|
ck::index_t ABlockTransferSrcVectorDim,
|
||||||
|
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||||
|
ck::index_t ABlockTransferDstScalarPerVector_M1,
|
||||||
|
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
typename BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||||
|
typename BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||||
|
typename BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
typename BBlockTransferSrcAccessOrder,
|
||||||
|
ck::index_t BBlockTransferSrcVectorDim,
|
||||||
|
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||||
|
ck::index_t BBlockTransferDstScalarPerVector_N1,
|
||||||
|
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
|
typename AGridStepHacks,
|
||||||
|
typename BGridStepHacks,
|
||||||
|
typename CGridStepHacks,
|
||||||
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
|
__host__ float driver_gemm_dlops_v1r2(const FloatAB* p_a_grid,
|
||||||
|
const FloatAB* p_b_grid,
|
||||||
|
FloatC* p_c_grid,
|
||||||
|
const AKMGridDesc& a_k_m_grid_desc,
|
||||||
|
const BKNGridDesc& b_k_n_grid_desc,
|
||||||
|
const CMNGridDesc& c_m_n_grid_desc,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
|
{
|
||||||
|
using namespace ck;
|
||||||
|
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
constexpr auto I3 = Number<3>{};
|
||||||
|
constexpr auto I4 = Number<4>{};
|
||||||
|
constexpr auto I5 = Number<5>{};
|
||||||
|
|
||||||
|
// GEMM
|
||||||
|
using GridwiseGemm = GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
CGlobalMemoryDataOperation,
|
||||||
|
AKMGridDesc,
|
||||||
|
BKNGridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
M1PerThread,
|
||||||
|
N1PerThread,
|
||||||
|
KPerThread,
|
||||||
|
M1N1ThreadClusterM10,
|
||||||
|
M1N1ThreadClusterN10,
|
||||||
|
M1N1ThreadClusterM11,
|
||||||
|
M1N1ThreadClusterN11,
|
||||||
|
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||||
|
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_M1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||||
|
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_N1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
|
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||||
|
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||||
|
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||||
|
|
||||||
|
if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
|
||||||
|
{
|
||||||
|
throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r2 has invalid setting");
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||||
|
const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||||
|
|
||||||
|
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
|
||||||
|
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);
|
||||||
|
|
||||||
|
// c_m0_m10_m11_n0_n10_n11_grid_desc
|
||||||
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||||
|
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||||
|
|
||||||
|
// c_blockid_to_m0_n0_block_cluster_adaptor
|
||||||
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
|
||||||
|
|
||||||
|
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
|
||||||
|
|
||||||
|
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);
|
||||||
|
|
||||||
|
{
|
||||||
|
std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
|
||||||
|
<< "}" << std::endl;
|
||||||
|
|
||||||
|
std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
|
||||||
|
<< "}" << std::endl;
|
||||||
|
|
||||||
|
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||||
|
float ave_time = 0;
|
||||||
|
|
||||||
|
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k_m0_m1_grid_desc,
|
||||||
|
b_k_n0_n1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k_m0_m1_grid_desc,
|
||||||
|
b_k_n0_n1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k_m0_m1_grid_desc,
|
||||||
|
b_k_n0_n1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k_m0_m1_grid_desc,
|
||||||
|
b_k_n0_n1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ave_time;
|
||||||
|
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||||
|
DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
|
||||||
|
DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
|
||||||
|
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
|
||||||
|
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
|
||||||
|
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
|
||||||
|
|
||||||
|
a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
|
||||||
|
b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
|
||||||
|
&c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
float ave_time = 0;
|
||||||
|
|
||||||
|
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r2<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AKM0M1GridDesc>,
|
||||||
|
remove_reference_t<BKN0N1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ave_time;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
#endif
|
||||||
418
host/driver_offline/include/driver_gemm_dlops_v1r3.hpp
Normal file
418
host/driver_offline/include/driver_gemm_dlops_v1r3.hpp
Normal file
@@ -0,0 +1,418 @@
|
|||||||
|
#ifndef DRIVER_GEMM_DLOPS_V1R3
|
||||||
|
#define DRIVER_GEMM_DLOPS_V1R3
|
||||||
|
|
||||||
|
#include "common_header.hpp"
|
||||||
|
#include "tensor_descriptor.hpp"
|
||||||
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
#include "gridwise_gemm_dlops_v1r3.hpp"
|
||||||
|
|
||||||
|
template <ck::index_t BlockSize,
|
||||||
|
typename FloatAB,
|
||||||
|
typename FloatAcc,
|
||||||
|
typename FloatC,
|
||||||
|
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||||
|
typename AK0MK1GridDesc,
|
||||||
|
typename BK0NK1GridDesc,
|
||||||
|
typename CMNGridDesc,
|
||||||
|
ck::index_t MPerBlock,
|
||||||
|
ck::index_t NPerBlock,
|
||||||
|
ck::index_t KPerBlock,
|
||||||
|
ck::index_t M1PerThread,
|
||||||
|
ck::index_t N1PerThread,
|
||||||
|
ck::index_t KPerThread,
|
||||||
|
typename M1N1ThreadClusterM1Xs,
|
||||||
|
typename M1N1ThreadClusterN1Xs,
|
||||||
|
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||||
|
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||||
|
typename ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
typename ABlockTransferSrcAccessOrder,
|
||||||
|
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||||
|
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||||
|
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||||
|
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||||
|
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||||
|
typename BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
typename BBlockTransferSrcAccessOrder,
|
||||||
|
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||||
|
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||||
|
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||||
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
|
typename AGridStepHacks,
|
||||||
|
typename BGridStepHacks,
|
||||||
|
typename CGridStepHacks,
|
||||||
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
|
typename BGridMoveSliceWindowStepHacks>
|
||||||
|
__host__ float driver_gemm_dlops_v1r3(const FloatAB* p_a_grid,
|
||||||
|
const FloatAB* p_b_grid,
|
||||||
|
FloatC* p_c_grid,
|
||||||
|
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||||
|
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||||
|
const CMNGridDesc& c_m_n_grid_desc,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
|
{
|
||||||
|
using namespace ck;
|
||||||
|
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
constexpr auto I3 = Number<3>{};
|
||||||
|
constexpr auto I4 = Number<4>{};
|
||||||
|
constexpr auto I5 = Number<5>{};
|
||||||
|
|
||||||
|
// GEMM
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemmDlops_km_kn_mn_v1r3<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
CGlobalMemoryDataOperation,
|
||||||
|
AK0MK1GridDesc,
|
||||||
|
BK0NK1GridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
M1PerThread,
|
||||||
|
N1PerThread,
|
||||||
|
KPerThread,
|
||||||
|
M1N1ThreadClusterM1Xs,
|
||||||
|
M1N1ThreadClusterN1Xs,
|
||||||
|
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||||
|
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||||
|
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||||
|
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||||
|
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||||
|
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks>;
|
||||||
|
|
||||||
|
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||||
|
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||||
|
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||||
|
|
||||||
|
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||||
|
{
|
||||||
|
throw std::runtime_error("wrong! GridwiseGemmDlops_km_kn_mn_v1r3 has invalid setting");
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto a_k0_m0_m1_k1_grid_desc =
|
||||||
|
GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
|
||||||
|
const auto b_k0_n0_n1_k1_grid_desc =
|
||||||
|
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);
|
||||||
|
|
||||||
|
using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
|
||||||
|
using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);
|
||||||
|
|
||||||
|
// c_m0_m10_m11_n0_n10_n11_grid_desc
|
||||||
|
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||||
|
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||||
|
|
||||||
|
// c_blockid_to_m0_n0_block_cluster_adaptor
|
||||||
|
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||||
|
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);
|
||||||
|
|
||||||
|
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
||||||
|
|
||||||
|
const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
|
||||||
|
|
||||||
|
{
|
||||||
|
std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
|
||||||
|
<< a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
|
||||||
|
<< a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||||
|
|
||||||
|
std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
|
||||||
|
<< b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
|
||||||
|
<< b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;
|
||||||
|
|
||||||
|
std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
|
||||||
|
<< c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||||
|
float ave_time = 0;
|
||||||
|
|
||||||
|
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k0_m0_m1_k1_grid_desc,
|
||||||
|
b_k0_n0_n1_k1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k0_m0_m1_k1_grid_desc,
|
||||||
|
b_k0_n0_n1_k1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k0_m0_m1_k1_grid_desc,
|
||||||
|
b_k0_n0_n1_k1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k0_m0_m1_k1_grid_desc,
|
||||||
|
b_k0_n0_n1_k1_grid_desc,
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ave_time;
|
||||||
|
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||||
|
DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
|
||||||
|
DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
|
||||||
|
DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
|
||||||
|
DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
|
||||||
|
sizeof(CBlockIdToM0N0BlockClusterAdaptor));
|
||||||
|
|
||||||
|
a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
|
||||||
|
b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
|
||||||
|
&c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||||
|
|
||||||
|
float ave_time = 0;
|
||||||
|
|
||||||
|
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
true,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
true>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const auto kernel =
|
||||||
|
kernel_gemm_dlops_v1r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0M0M1K1GridDesc>,
|
||||||
|
remove_reference_t<BK0N0N1K1GridDesc>,
|
||||||
|
remove_reference_t<CM0M10M11N0N10N11GridDesc>,
|
||||||
|
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
|
||||||
|
false,
|
||||||
|
false>;
|
||||||
|
|
||||||
|
ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(
|
||||||
|
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ave_time;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
#endif
|
||||||
191
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
Normal file
191
host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
#ifndef DRIVER_GEMM_XDLOPS_V2R3
|
||||||
|
#define DRIVER_GEMM_XDLOPS_V2R3
|
||||||
|
|
||||||
|
#include "common_header.hpp"
|
||||||
|
#include "tensor_descriptor.hpp"
|
||||||
|
#include "tensor_descriptor_helper.hpp"
|
||||||
|
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||||
|
|
||||||
|
template <ck::index_t BlockSize,
|
||||||
|
typename FloatAB,
|
||||||
|
typename FloatAcc,
|
||||||
|
typename FloatC,
|
||||||
|
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||||
|
typename AK0MK1GridDesc,
|
||||||
|
typename BK0NK1GridDesc,
|
||||||
|
typename CMNGridDesc,
|
||||||
|
ck::index_t MPerBlock,
|
||||||
|
ck::index_t NPerBlock,
|
||||||
|
ck::index_t KPerBlock,
|
||||||
|
ck::index_t MPerWave,
|
||||||
|
ck::index_t NPerWave,
|
||||||
|
ck::index_t K1,
|
||||||
|
ck::index_t MRepeat,
|
||||||
|
ck::index_t NRepeat,
|
||||||
|
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
|
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
typename ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
typename ABlockTransferSrcAccessOrder,
|
||||||
|
ck::index_t ABlockTransferSrcVectorDim,
|
||||||
|
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||||
|
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||||
|
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
|
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
typename BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
typename BBlockTransferSrcAccessOrder,
|
||||||
|
ck::index_t BBlockTransferSrcVectorDim,
|
||||||
|
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||||
|
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||||
|
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
typename CThreadTransferSrcDstAccessOrder,
|
||||||
|
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||||
|
ck::index_t CThreadTransferDstScalarPerVector,
|
||||||
|
typename AGridStepHacks,
|
||||||
|
typename BGridStepHacks,
|
||||||
|
typename CGridStepHacks,
|
||||||
|
typename AGridMoveSliceWindowStepHacks,
|
||||||
|
typename BGridMoveSliceWindowStepHacks,
|
||||||
|
bool CAccessOrderMRepeatNRepeat>
|
||||||
|
__host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||||
|
const FloatAB* p_b_grid,
|
||||||
|
FloatC* p_c_grid,
|
||||||
|
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||||
|
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||||
|
const CMNGridDesc& c_m_n_grid_desc,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
ck::index_t nrepeat)
|
||||||
|
|
||||||
|
{
|
||||||
|
using namespace ck;
|
||||||
|
|
||||||
|
constexpr auto I0 = Number<0>{};
|
||||||
|
constexpr auto I1 = Number<1>{};
|
||||||
|
constexpr auto I2 = Number<2>{};
|
||||||
|
|
||||||
|
using GridwiseGemm =
|
||||||
|
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||||
|
FloatAB,
|
||||||
|
FloatAcc,
|
||||||
|
FloatC,
|
||||||
|
CGlobalMemoryDataOperation,
|
||||||
|
AK0MK1GridDesc,
|
||||||
|
BK0NK1GridDesc,
|
||||||
|
CMNGridDesc,
|
||||||
|
MPerBlock,
|
||||||
|
NPerBlock,
|
||||||
|
KPerBlock,
|
||||||
|
MPerWave,
|
||||||
|
NPerWave,
|
||||||
|
K1,
|
||||||
|
MRepeat,
|
||||||
|
NRepeat,
|
||||||
|
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||||
|
ABlockTransferThreadClusterArrangeOrder,
|
||||||
|
ABlockTransferSrcAccessOrder,
|
||||||
|
ABlockTransferSrcVectorDim,
|
||||||
|
ABlockTransferSrcScalarPerVector,
|
||||||
|
ABlockTransferDstScalarPerVector_K1,
|
||||||
|
AThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||||
|
BBlockTransferThreadClusterArrangeOrder,
|
||||||
|
BBlockTransferSrcAccessOrder,
|
||||||
|
BBlockTransferSrcVectorDim,
|
||||||
|
BBlockTransferSrcScalarPerVector,
|
||||||
|
BBlockTransferDstScalarPerVector_K1,
|
||||||
|
BThreadTransferSrcResetCoordinateAfterRun,
|
||||||
|
CThreadTransferSrcDstAccessOrder,
|
||||||
|
CThreadTransferSrcDstVectorDim,
|
||||||
|
CThreadTransferDstScalarPerVector,
|
||||||
|
AGridStepHacks,
|
||||||
|
BGridStepHacks,
|
||||||
|
CGridStepHacks,
|
||||||
|
AGridMoveSliceWindowStepHacks,
|
||||||
|
BGridMoveSliceWindowStepHacks,
|
||||||
|
CAccessOrderMRepeatNRepeat>;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
|
||||||
|
<< "}" << std::endl;
|
||||||
|
|
||||||
|
std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
|
||||||
|
<< "}" << std::endl;
|
||||||
|
|
||||||
|
std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
|
||||||
|
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
|
||||||
|
{
|
||||||
|
throw std::runtime_error(
|
||||||
|
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);
|
||||||
|
|
||||||
|
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
|
||||||
|
|
||||||
|
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);
|
||||||
|
|
||||||
|
const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm,
|
||||||
|
FloatAB,
|
||||||
|
FloatC,
|
||||||
|
remove_reference_t<AK0MK1GridDesc>,
|
||||||
|
remove_reference_t<BK0NK1GridDesc>,
|
||||||
|
remove_reference_t<CM0M1M2NGridDesc>,
|
||||||
|
remove_reference_t<CBlockClusterAdaptor>>;
|
||||||
|
|
||||||
|
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||||
|
float ave_time = launch_and_time_kernel(kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
a_k0_m_k1_grid_desc,
|
||||||
|
b_k0_n_k1_grid_desc,
|
||||||
|
c_m0_m1_m2_n_grid_desc,
|
||||||
|
c_block_cluster_adaptor);
|
||||||
|
|
||||||
|
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||||
|
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
|
||||||
|
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
|
||||||
|
DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
|
||||||
|
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
|
||||||
|
|
||||||
|
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
|
||||||
|
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
|
||||||
|
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
|
||||||
|
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
|
||||||
|
|
||||||
|
float ave_time = launch_and_time_kernel(
|
||||||
|
kernel,
|
||||||
|
nrepeat,
|
||||||
|
dim3(grid_size),
|
||||||
|
dim3(BlockSize),
|
||||||
|
0,
|
||||||
|
p_a_grid,
|
||||||
|
p_b_grid,
|
||||||
|
p_c_grid,
|
||||||
|
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
|
||||||
|
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
|
||||||
|
#endif
|
||||||
|
return ave_time;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -12,10 +12,10 @@
|
|||||||
#include "conv_common.hpp"
|
#include "conv_common.hpp"
|
||||||
#include "host_conv_bwd_data.hpp"
|
#include "host_conv_bwd_data.hpp"
|
||||||
#include "device_tensor.hpp"
|
#include "device_tensor.hpp"
|
||||||
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
#include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||||
|
|
||||||
#define USE_DYNAMIC_MODE 1
|
#define USE_MODE 1
|
||||||
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
|
#define USE_CONV_BWD_V4R1_XDL_NHWC 1
|
||||||
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ int main(int argc, char* argv[])
|
|||||||
constexpr auto I5 = Number<5>{};
|
constexpr auto I5 = Number<5>{};
|
||||||
constexpr auto I6 = Number<6>{};
|
constexpr auto I6 = Number<6>{};
|
||||||
|
|
||||||
#if USE_DYNAMIC_MODE
|
#if USE_MODE
|
||||||
// dynamic mode
|
// dynamic mode
|
||||||
if(argc != 22)
|
if(argc != 22)
|
||||||
{
|
{
|
||||||
@@ -46,29 +46,29 @@ int main(int argc, char* argv[])
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||||
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(atoi(argv[2]));
|
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
|
||||||
const bool do_verification = atoi(argv[3]);
|
const bool do_verification = std::stoi(argv[3]);
|
||||||
const int init_method = atoi(argv[4]);
|
const int init_method = std::stoi(argv[4]);
|
||||||
const bool do_log = atoi(argv[5]);
|
const bool do_log = std::stoi(argv[5]);
|
||||||
const int nrepeat = atoi(argv[6]);
|
const int nrepeat = std::stoi(argv[6]);
|
||||||
|
|
||||||
const index_t N = atoi(argv[7]);
|
const index_t N = std::stoi(argv[7]);
|
||||||
const index_t K = atoi(argv[8]);
|
const index_t K = std::stoi(argv[8]);
|
||||||
const index_t C = atoi(argv[9]);
|
const index_t C = std::stoi(argv[9]);
|
||||||
const index_t Y = atoi(argv[10]);
|
const index_t Y = std::stoi(argv[10]);
|
||||||
const index_t X = atoi(argv[11]);
|
const index_t X = std::stoi(argv[11]);
|
||||||
const index_t Hi = atoi(argv[12]);
|
const index_t Hi = std::stoi(argv[12]);
|
||||||
const index_t Wi = atoi(argv[13]);
|
const index_t Wi = std::stoi(argv[13]);
|
||||||
|
|
||||||
const index_t conv_stride_h = atoi(argv[14]);
|
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||||
const index_t conv_stride_w = atoi(argv[15]);
|
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||||
const index_t conv_dilation_h = atoi(argv[16]);
|
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||||
const index_t conv_dilation_w = atoi(argv[17]);
|
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||||
const index_t in_left_pad_h = atoi(argv[18]);
|
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||||
const index_t in_left_pad_w = atoi(argv[19]);
|
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||||
const index_t in_right_pad_h = atoi(argv[20]);
|
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||||
const index_t in_right_pad_w = atoi(argv[21]);
|
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||||
|
|
||||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||||
@@ -83,12 +83,12 @@ int main(int argc, char* argv[])
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||||
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(atoi(argv[2]));
|
const ConvBackwardDataAlgo algo = static_cast<ConvBackwardDataAlgo>(std::stoi(argv[2]));
|
||||||
const bool do_verification = atoi(argv[3]);
|
const bool do_verification = std::stoi(argv[3]);
|
||||||
const int init_method = atoi(argv[4]);
|
const int init_method = std::stoi(argv[4]);
|
||||||
const bool do_log = atoi(argv[5]);
|
const bool do_log = std::stoi(argv[5]);
|
||||||
const int nrepeat = atoi(argv[6]);
|
const int nrepeat = std::stoi(argv[6]);
|
||||||
|
|
||||||
constexpr index_t N = 128;
|
constexpr index_t N = 128;
|
||||||
constexpr index_t C = 192;
|
constexpr index_t C = 192;
|
||||||
@@ -115,23 +115,19 @@ int main(int argc, char* argv[])
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
constexpr index_t in_vector_size = 1;
|
|
||||||
using in_data_t = float;
|
using in_data_t = float;
|
||||||
using acc_data_t = float;
|
using acc_data_t = float;
|
||||||
using out_data_t = float;
|
using out_data_t = float;
|
||||||
#elif 1
|
#elif 1
|
||||||
constexpr index_t in_vector_size = 1;
|
using in_data_t = half_t;
|
||||||
using in_data_t = half_t;
|
using acc_data_t = float;
|
||||||
using acc_data_t = float;
|
using out_data_t = half_t;
|
||||||
using out_data_t = half_t;
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||||
|
|
||||||
switch(layout)
|
if(layout == ConvTensorLayout::NCHW)
|
||||||
{
|
{
|
||||||
case ConvTensorLayout::NCHW:
|
|
||||||
// NCHW
|
|
||||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||||
@@ -144,9 +140,9 @@ int main(int argc, char* argv[])
|
|||||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||||
break;
|
}
|
||||||
case ConvTensorLayout::NHWC:
|
else if(layout == ConvTensorLayout::NHWC)
|
||||||
// NHWC
|
{
|
||||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||||
@@ -159,8 +155,10 @@ int main(int argc, char* argv[])
|
|||||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||||
break;
|
}
|
||||||
default: throw std::runtime_error("wrong! not implemented");
|
else
|
||||||
|
{
|
||||||
|
throw std::runtime_error("wrong! not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor<in_data_t> in_host(in_lengths_host);
|
Tensor<in_data_t> in_host(in_lengths_host);
|
||||||
@@ -213,40 +211,8 @@ int main(int argc, char* argv[])
|
|||||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto f_make_for_device_nchw = [&]() {
|
|
||||||
#if USE_DYNAMIC_MODE
|
|
||||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
|
||||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
|
||||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
|
||||||
const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w);
|
|
||||||
const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w);
|
|
||||||
const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w);
|
|
||||||
const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w);
|
|
||||||
#else
|
|
||||||
const auto in_lengths_dev =
|
|
||||||
make_tuple(Number<N>{}, Number<C>{}, Number<Hi>{}, Number<Wi>{});
|
|
||||||
const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<C>{}, Number<Y>{}, Number<X>{});
|
|
||||||
const auto out_lengths_dev =
|
|
||||||
make_tuple(Number<N>{}, Number<K>{}, Number<Ho>{}, Number<Wo>{});
|
|
||||||
const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{});
|
|
||||||
const auto conv_dilations_dev =
|
|
||||||
make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{});
|
|
||||||
const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{});
|
|
||||||
const auto in_right_pads_dev =
|
|
||||||
make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{});
|
|
||||||
#endif
|
|
||||||
|
|
||||||
return make_tuple(in_lengths_dev,
|
|
||||||
wei_lengths_dev,
|
|
||||||
out_lengths_dev,
|
|
||||||
conv_strides_dev,
|
|
||||||
conv_dilations_dev,
|
|
||||||
in_left_pads_dev,
|
|
||||||
in_right_pads_dev);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto f_make_for_device_nhwc = [&]() {
|
auto f_make_for_device_nhwc = [&]() {
|
||||||
#if USE_DYNAMIC_MODE
|
#if USE_MODE
|
||||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||||
@@ -277,8 +243,6 @@ int main(int argc, char* argv[])
|
|||||||
in_right_pads_dev);
|
in_right_pads_dev);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto nhwc_desc = f_make_for_device_nhwc();
|
|
||||||
|
|
||||||
#if USE_CONV_BWD_V4R1_XDL_NHWC
|
#if USE_CONV_BWD_V4R1_XDL_NHWC
|
||||||
if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC)
|
if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC)
|
||||||
{
|
{
|
||||||
@@ -289,20 +253,20 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nhwc();
|
const auto tmp = f_make_for_device_nhwc();
|
||||||
|
|
||||||
device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<
|
device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||||
in_data_t,
|
acc_data_t,
|
||||||
acc_data_t,
|
out_data_t>(
|
||||||
out_data_t>(tmp[I0],
|
tmp[I0],
|
||||||
tmp[I1],
|
tmp[I1],
|
||||||
tmp[I2],
|
tmp[I2],
|
||||||
tmp[I3],
|
tmp[I3],
|
||||||
tmp[I4],
|
tmp[I4],
|
||||||
tmp[I5],
|
tmp[I5],
|
||||||
tmp[I6],
|
tmp[I6],
|
||||||
in_device,
|
in_device,
|
||||||
wei,
|
wei,
|
||||||
out,
|
out,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -316,20 +280,20 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nhwc();
|
const auto tmp = f_make_for_device_nhwc();
|
||||||
|
|
||||||
device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<
|
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||||
in_data_t,
|
acc_data_t,
|
||||||
acc_data_t,
|
out_data_t>(
|
||||||
out_data_t>(tmp[I0],
|
tmp[I0],
|
||||||
tmp[I1],
|
tmp[I1],
|
||||||
tmp[I2],
|
tmp[I2],
|
||||||
tmp[I3],
|
tmp[I3],
|
||||||
tmp[I4],
|
tmp[I4],
|
||||||
tmp[I5],
|
tmp[I5],
|
||||||
tmp[I6],
|
tmp[I6],
|
||||||
in_device,
|
in_device,
|
||||||
wei,
|
wei,
|
||||||
out,
|
out,
|
||||||
nrepeat);
|
nrepeat);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -12,17 +12,17 @@
|
|||||||
#include "conv_common.hpp"
|
#include "conv_common.hpp"
|
||||||
#include "host_conv.hpp"
|
#include "host_conv.hpp"
|
||||||
#include "device_tensor.hpp"
|
#include "device_tensor.hpp"
|
||||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
#include "device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
||||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
|
#include "device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp"
|
||||||
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
#include "device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
#include "device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||||
|
|
||||||
#define USE_DYNAMIC_MODE 1
|
#define USE_MODE 1
|
||||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
#define USE_CONV_FWD_V4R4_NCHW 1
|
||||||
#define USE_CONV_FWD_V4R4R2_NHWC 1
|
#define USE_CONV_FWD_V4R4R2_NHWC 1
|
||||||
#define USE_CONV_FWD_V6R1_NCHW 1
|
#define USE_CONV_FWD_V6R1_NCHW 0
|
||||||
#define USE_CONV_FWD_V5R1_NCHW 0
|
#define USE_CONV_FWD_V5R1_NCHW 0
|
||||||
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
|
||||||
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
|
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
|
||||||
@@ -49,7 +49,7 @@ int main(int argc, char* argv[])
|
|||||||
constexpr auto I5 = Number<5>{};
|
constexpr auto I5 = Number<5>{};
|
||||||
constexpr auto I6 = Number<6>{};
|
constexpr auto I6 = Number<6>{};
|
||||||
|
|
||||||
#if USE_DYNAMIC_MODE
|
#if USE_MODE
|
||||||
// dynamic mode
|
// dynamic mode
|
||||||
if(argc != 22)
|
if(argc != 22)
|
||||||
{
|
{
|
||||||
@@ -58,29 +58,29 @@ int main(int argc, char* argv[])
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
|
||||||
const bool do_verification = atoi(argv[3]);
|
const bool do_verification = std::stoi(argv[3]);
|
||||||
const int init_method = atoi(argv[4]);
|
const int init_method = std::stoi(argv[4]);
|
||||||
const bool do_log = atoi(argv[5]);
|
const bool do_log = std::stoi(argv[5]);
|
||||||
const int nrepeat = atoi(argv[6]);
|
const int nrepeat = std::stoi(argv[6]);
|
||||||
|
|
||||||
const index_t N = atoi(argv[7]);
|
const index_t N = std::stoi(argv[7]);
|
||||||
const index_t K = atoi(argv[8]);
|
const index_t K = std::stoi(argv[8]);
|
||||||
const index_t C = atoi(argv[9]);
|
const index_t C = std::stoi(argv[9]);
|
||||||
const index_t Y = atoi(argv[10]);
|
const index_t Y = std::stoi(argv[10]);
|
||||||
const index_t X = atoi(argv[11]);
|
const index_t X = std::stoi(argv[11]);
|
||||||
const index_t Hi = atoi(argv[12]);
|
const index_t Hi = std::stoi(argv[12]);
|
||||||
const index_t Wi = atoi(argv[13]);
|
const index_t Wi = std::stoi(argv[13]);
|
||||||
|
|
||||||
const index_t conv_stride_h = atoi(argv[14]);
|
const index_t conv_stride_h = std::stoi(argv[14]);
|
||||||
const index_t conv_stride_w = atoi(argv[15]);
|
const index_t conv_stride_w = std::stoi(argv[15]);
|
||||||
const index_t conv_dilation_h = atoi(argv[16]);
|
const index_t conv_dilation_h = std::stoi(argv[16]);
|
||||||
const index_t conv_dilation_w = atoi(argv[17]);
|
const index_t conv_dilation_w = std::stoi(argv[17]);
|
||||||
const index_t in_left_pad_h = atoi(argv[18]);
|
const index_t in_left_pad_h = std::stoi(argv[18]);
|
||||||
const index_t in_left_pad_w = atoi(argv[19]);
|
const index_t in_left_pad_w = std::stoi(argv[19]);
|
||||||
const index_t in_right_pad_h = atoi(argv[20]);
|
const index_t in_right_pad_h = std::stoi(argv[20]);
|
||||||
const index_t in_right_pad_w = atoi(argv[21]);
|
const index_t in_right_pad_w = std::stoi(argv[21]);
|
||||||
|
|
||||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
||||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
||||||
@@ -95,12 +95,12 @@ int main(int argc, char* argv[])
|
|||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(std::stoi(argv[1]));
|
||||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::stoi(argv[2]));
|
||||||
const bool do_verification = atoi(argv[3]);
|
const bool do_verification = std::stoi(argv[3]);
|
||||||
const int init_method = atoi(argv[4]);
|
const int init_method = std::stoi(argv[4]);
|
||||||
const bool do_log = atoi(argv[5]);
|
const bool do_log = std::stoi(argv[5]);
|
||||||
const int nrepeat = atoi(argv[6]);
|
const int nrepeat = std::stoi(argv[6]);
|
||||||
|
|
||||||
constexpr index_t N = 128;
|
constexpr index_t N = 128;
|
||||||
constexpr index_t C = 192;
|
constexpr index_t C = 192;
|
||||||
@@ -142,10 +142,8 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
||||||
|
|
||||||
switch(layout)
|
if(layout == ConvTensorLayout::NCHW)
|
||||||
{
|
{
|
||||||
case ConvTensorLayout::NCHW:
|
|
||||||
// NCHW
|
|
||||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
in_lengths_host[1] = static_cast<std::size_t>(C);
|
||||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
||||||
@@ -158,9 +156,9 @@ int main(int argc, char* argv[])
|
|||||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
out_lengths_host[1] = static_cast<std::size_t>(K);
|
||||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
||||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
||||||
break;
|
}
|
||||||
case ConvTensorLayout::NHWC:
|
else if(layout == ConvTensorLayout::NHWC)
|
||||||
// NHWC
|
{
|
||||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
in_lengths_host[0] = static_cast<std::size_t>(N);
|
||||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
||||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
||||||
@@ -173,8 +171,10 @@ int main(int argc, char* argv[])
|
|||||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
||||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
||||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
out_lengths_host[3] = static_cast<std::size_t>(K);
|
||||||
break;
|
}
|
||||||
default: throw std::runtime_error("wrong! not implemented");
|
else
|
||||||
|
{
|
||||||
|
std::runtime_error("wrong! not implemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor<in_data_t> in(in_lengths_host);
|
Tensor<in_data_t> in(in_lengths_host);
|
||||||
@@ -228,7 +228,7 @@ int main(int argc, char* argv[])
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto f_make_for_device_nchw = [&]() {
|
auto f_make_for_device_nchw = [&]() {
|
||||||
#if USE_DYNAMIC_MODE
|
#if USE_MODE
|
||||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
||||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
||||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
||||||
@@ -260,7 +260,7 @@ int main(int argc, char* argv[])
|
|||||||
};
|
};
|
||||||
|
|
||||||
auto f_make_for_device_nhwc = [&]() {
|
auto f_make_for_device_nhwc = [&]() {
|
||||||
#if USE_DYNAMIC_MODE
|
#if USE_MODE
|
||||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
||||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
||||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
||||||
@@ -301,20 +301,19 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
const auto tmp = f_make_for_device_nchw();
|
||||||
|
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
|
device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||||
acc_data_t,
|
acc_data_t,
|
||||||
out_data_t>(
|
out_data_t>(tmp[I0],
|
||||||
tmp[I0],
|
tmp[I1],
|
||||||
tmp[I1],
|
tmp[I2],
|
||||||
tmp[I2],
|
tmp[I3],
|
||||||
tmp[I3],
|
tmp[I4],
|
||||||
tmp[I4],
|
tmp[I5],
|
||||||
tmp[I5],
|
tmp[I6],
|
||||||
tmp[I6],
|
in,
|
||||||
in,
|
wei,
|
||||||
wei,
|
out_device,
|
||||||
out_device,
|
nrepeat);
|
||||||
nrepeat);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -328,20 +327,19 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nhwc();
|
const auto tmp = f_make_for_device_nhwc();
|
||||||
|
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
|
device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk<in_data_t,
|
||||||
acc_data_t,
|
acc_data_t,
|
||||||
out_data_t>(
|
out_data_t>(tmp[I0],
|
||||||
tmp[I0],
|
tmp[I1],
|
||||||
tmp[I1],
|
tmp[I2],
|
||||||
tmp[I2],
|
tmp[I3],
|
||||||
tmp[I3],
|
tmp[I4],
|
||||||
tmp[I4],
|
tmp[I5],
|
||||||
tmp[I5],
|
tmp[I6],
|
||||||
tmp[I6],
|
in,
|
||||||
in,
|
wei,
|
||||||
wei,
|
out_device,
|
||||||
out_device,
|
nrepeat);
|
||||||
nrepeat);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -355,20 +353,19 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
const auto tmp = f_make_for_device_nchw();
|
||||||
|
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||||
acc_data_t,
|
acc_data_t,
|
||||||
out_data_t>(
|
out_data_t>(tmp[I0],
|
||||||
tmp[I0],
|
tmp[I1],
|
||||||
tmp[I1],
|
tmp[I2],
|
||||||
tmp[I2],
|
tmp[I3],
|
||||||
tmp[I3],
|
tmp[I4],
|
||||||
tmp[I4],
|
tmp[I5],
|
||||||
tmp[I5],
|
tmp[I6],
|
||||||
tmp[I6],
|
in,
|
||||||
in,
|
wei,
|
||||||
wei,
|
out_device,
|
||||||
out_device,
|
nrepeat);
|
||||||
nrepeat);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -382,21 +379,20 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
const auto tmp = f_make_for_device_nchw();
|
||||||
|
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw<in_data_t,
|
||||||
16,
|
16,
|
||||||
acc_data_t,
|
acc_data_t,
|
||||||
out_data_t>(
|
out_data_t>(tmp[I0],
|
||||||
tmp[I0],
|
tmp[I1],
|
||||||
tmp[I1],
|
tmp[I2],
|
||||||
tmp[I2],
|
tmp[I3],
|
||||||
tmp[I3],
|
tmp[I4],
|
||||||
tmp[I4],
|
tmp[I5],
|
||||||
tmp[I5],
|
tmp[I6],
|
||||||
tmp[I6],
|
in,
|
||||||
in,
|
wei,
|
||||||
wei,
|
out_device,
|
||||||
out_device,
|
nrepeat);
|
||||||
nrepeat);
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -410,9 +406,9 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
const auto tmp = f_make_for_device_nchw();
|
||||||
|
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
|
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw<in_data_t,
|
||||||
acc_data_t,
|
acc_data_t,
|
||||||
out_data_t>(
|
out_data_t>(
|
||||||
tmp[I0],
|
tmp[I0],
|
||||||
tmp[I1],
|
tmp[I1],
|
||||||
tmp[I2],
|
tmp[I2],
|
||||||
@@ -437,9 +433,9 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
const auto tmp = f_make_for_device_nhwc();
|
const auto tmp = f_make_for_device_nhwc();
|
||||||
|
|
||||||
device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<in_data_t,
|
||||||
acc_data_t,
|
acc_data_t,
|
||||||
out_data_t>(
|
out_data_t>(
|
||||||
tmp[I0],
|
tmp[I0],
|
||||||
tmp[I1],
|
tmp[I1],
|
||||||
tmp[I2],
|
tmp[I2],
|
||||||
@@ -467,7 +463,6 @@ int main(int argc, char* argv[])
|
|||||||
|
|
||||||
check_error(out_host, out_device);
|
check_error(out_host, out_device);
|
||||||
|
|
||||||
#if 0
|
|
||||||
if(do_log)
|
if(do_log)
|
||||||
{
|
{
|
||||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||||
@@ -475,6 +470,5 @@ int main(int argc, char* argv[])
|
|||||||
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||||
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
include_directories(BEFORE
|
|
||||||
include
|
|
||||||
${PROJECT_BINARY_DIR}/host/online_compile/include
|
|
||||||
${PROJECT_SOURCE_DIR}/host/online_compile/include
|
|
||||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
|
||||||
${PROJECT_SOURCE_DIR}/host/solver/include
|
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
|
||||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
|
|
||||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
|
||||||
${PROJECT_SOURCE_DIR}/external/half/include
|
|
||||||
)
|
|
||||||
|
|
||||||
set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp)
|
|
||||||
|
|
||||||
add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE})
|
|
||||||
|
|
||||||
target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor)
|
|
||||||
target_link_libraries(conv_fwd_driver_online PRIVATE online_compile)
|
|
||||||
@@ -1,453 +0,0 @@
|
|||||||
#include <iostream>
|
|
||||||
#include <numeric>
|
|
||||||
#include <initializer_list>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <stdlib.h>
|
|
||||||
#include <half.hpp>
|
|
||||||
#include "config.hpp"
|
|
||||||
#include "print.hpp"
|
|
||||||
#include "device.hpp"
|
|
||||||
#include "host_tensor.hpp"
|
|
||||||
#include "host_tensor_generator.hpp"
|
|
||||||
#include "conv_common.hpp"
|
|
||||||
#include "host_conv.hpp"
|
|
||||||
#include "device_tensor.hpp"
|
|
||||||
#include "handle.hpp"
|
|
||||||
#include "hipCheck.hpp"
|
|
||||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
|
||||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
|
||||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
|
||||||
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
|
||||||
|
|
||||||
#define USE_CONV_FWD_V4R4_NCHW 1
|
|
||||||
#define USE_CONV_FWD_V6R1_NCHW 1
|
|
||||||
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
|
|
||||||
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
|
|
||||||
|
|
||||||
enum ConvForwardAlgo
|
|
||||||
{
|
|
||||||
V4R4NCHW, // 0
|
|
||||||
V6R1NCHW, // 1
|
|
||||||
V4R4XDLNCHW, // 2
|
|
||||||
V4R4XDLNHWC // 3
|
|
||||||
};
|
|
||||||
|
|
||||||
int main(int argc, char* argv[])
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
using namespace ck_driver;
|
|
||||||
using size_t = std::size_t;
|
|
||||||
|
|
||||||
hipStream_t stream;
|
|
||||||
online_compile::Handle* handle;
|
|
||||||
|
|
||||||
MY_HIP_CHECK(hipStreamCreate(&stream));
|
|
||||||
|
|
||||||
handle = new online_compile::Handle(stream);
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
constexpr auto I4 = Number<4>{};
|
|
||||||
constexpr auto I5 = Number<5>{};
|
|
||||||
constexpr auto I6 = Number<6>{};
|
|
||||||
|
|
||||||
if(argc != 22)
|
|
||||||
{
|
|
||||||
printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n");
|
|
||||||
printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n");
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
const ConvTensorLayout layout = static_cast<ConvTensorLayout>(atoi(argv[1]));
|
|
||||||
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(atoi(argv[2]));
|
|
||||||
const bool do_verification = atoi(argv[3]);
|
|
||||||
const int init_method = atoi(argv[4]);
|
|
||||||
const bool do_log = atoi(argv[5]);
|
|
||||||
const int nrepeat = atoi(argv[6]);
|
|
||||||
|
|
||||||
const index_t N = atoi(argv[7]);
|
|
||||||
const index_t K = atoi(argv[8]);
|
|
||||||
const index_t C = atoi(argv[9]);
|
|
||||||
const index_t Y = atoi(argv[10]);
|
|
||||||
const index_t X = atoi(argv[11]);
|
|
||||||
const index_t Hi = atoi(argv[12]);
|
|
||||||
const index_t Wi = atoi(argv[13]);
|
|
||||||
|
|
||||||
const index_t conv_stride_h = atoi(argv[14]);
|
|
||||||
const index_t conv_stride_w = atoi(argv[15]);
|
|
||||||
const index_t conv_dilation_h = atoi(argv[16]);
|
|
||||||
const index_t conv_dilation_w = atoi(argv[17]);
|
|
||||||
const index_t in_left_pad_h = atoi(argv[18]);
|
|
||||||
const index_t in_left_pad_w = atoi(argv[19]);
|
|
||||||
const index_t in_right_pad_h = atoi(argv[20]);
|
|
||||||
const index_t in_right_pad_w = atoi(argv[21]);
|
|
||||||
|
|
||||||
const index_t YEff = (Y - 1) * conv_dilation_h + 1;
|
|
||||||
const index_t XEff = (X - 1) * conv_dilation_w + 1;
|
|
||||||
|
|
||||||
const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1;
|
|
||||||
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
|
||||||
|
|
||||||
#if 1
|
|
||||||
using in_data_t = float;
|
|
||||||
using acc_data_t = float;
|
|
||||||
using out_data_t = float;
|
|
||||||
#elif 0
|
|
||||||
using in_data_t = half_t;
|
|
||||||
using acc_data_t = float;
|
|
||||||
using out_data_t = half_t;
|
|
||||||
#elif 1
|
|
||||||
using in_data_t = int8_t;
|
|
||||||
using acc_data_t = int32_t;
|
|
||||||
using out_data_t = int8_t;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
|
|
||||||
|
|
||||||
switch(layout)
|
|
||||||
{
|
|
||||||
case ConvTensorLayout::NCHW:
|
|
||||||
// NCHW
|
|
||||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
|
||||||
in_lengths_host[1] = static_cast<std::size_t>(C);
|
|
||||||
in_lengths_host[2] = static_cast<std::size_t>(Hi);
|
|
||||||
in_lengths_host[3] = static_cast<std::size_t>(Wi);
|
|
||||||
|
|
||||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
|
||||||
wei_lengths_host[1] = static_cast<std::size_t>(C);
|
|
||||||
wei_lengths_host[2] = static_cast<std::size_t>(Y);
|
|
||||||
wei_lengths_host[3] = static_cast<std::size_t>(X);
|
|
||||||
|
|
||||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
|
||||||
out_lengths_host[1] = static_cast<std::size_t>(K);
|
|
||||||
out_lengths_host[2] = static_cast<std::size_t>(Ho);
|
|
||||||
out_lengths_host[3] = static_cast<std::size_t>(Wo);
|
|
||||||
break;
|
|
||||||
case ConvTensorLayout::NHWC:
|
|
||||||
// NHWC
|
|
||||||
in_lengths_host[0] = static_cast<std::size_t>(N);
|
|
||||||
in_lengths_host[1] = static_cast<std::size_t>(Hi);
|
|
||||||
in_lengths_host[2] = static_cast<std::size_t>(Wi);
|
|
||||||
in_lengths_host[3] = static_cast<std::size_t>(C);
|
|
||||||
|
|
||||||
wei_lengths_host[0] = static_cast<std::size_t>(K);
|
|
||||||
wei_lengths_host[1] = static_cast<std::size_t>(Y);
|
|
||||||
wei_lengths_host[2] = static_cast<std::size_t>(X);
|
|
||||||
wei_lengths_host[3] = static_cast<std::size_t>(C);
|
|
||||||
|
|
||||||
out_lengths_host[0] = static_cast<std::size_t>(N);
|
|
||||||
out_lengths_host[1] = static_cast<std::size_t>(Ho);
|
|
||||||
out_lengths_host[2] = static_cast<std::size_t>(Wo);
|
|
||||||
out_lengths_host[3] = static_cast<std::size_t>(K);
|
|
||||||
break;
|
|
||||||
default: throw std::runtime_error("wrong! not implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
Tensor<in_data_t> in(in_lengths_host);
|
|
||||||
Tensor<in_data_t> wei(wei_lengths_host);
|
|
||||||
Tensor<out_data_t> out_host(out_lengths_host);
|
|
||||||
Tensor<out_data_t> out_device(out_lengths_host);
|
|
||||||
|
|
||||||
std::cout << "layout: " << layout << std::endl;
|
|
||||||
ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
|
|
||||||
ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: ");
|
|
||||||
ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: ");
|
|
||||||
print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
|
|
||||||
print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
|
|
||||||
print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
|
|
||||||
print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
|
|
||||||
|
|
||||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
|
||||||
|
|
||||||
switch(init_method)
|
|
||||||
{
|
|
||||||
case 0:
|
|
||||||
// no initialization
|
|
||||||
break;
|
|
||||||
case 1:
|
|
||||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
|
||||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
|
||||||
break;
|
|
||||||
case 2:
|
|
||||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
|
||||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
|
||||||
break;
|
|
||||||
case 3:
|
|
||||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
|
||||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
|
||||||
break;
|
|
||||||
case 4:
|
|
||||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
|
||||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
|
||||||
break;
|
|
||||||
case 5:
|
|
||||||
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
|
||||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
|
||||||
|
|
||||||
auto gen_wei = [](auto... is) {
|
|
||||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
|
||||||
};
|
|
||||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto f_make_for_device_nchw = [&]() {
|
|
||||||
const auto in_lengths_dev = make_tuple(N, C, Hi, Wi);
|
|
||||||
const auto wei_lengths_dev = make_tuple(K, C, Y, X);
|
|
||||||
const auto out_lengths_dev = make_tuple(N, K, Ho, Wo);
|
|
||||||
|
|
||||||
return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto f_make_for_device_nhwc = [&]() {
|
|
||||||
const auto in_lengths_dev = make_tuple(N, Hi, Wi, C);
|
|
||||||
const auto wei_lengths_dev = make_tuple(K, Y, X, C);
|
|
||||||
const auto out_lengths_dev = make_tuple(N, Ho, Wo, K);
|
|
||||||
|
|
||||||
return make_tuple(in_lengths_dev, wei_lengths_dev, out_lengths_dev);
|
|
||||||
};
|
|
||||||
|
|
||||||
const auto conv_strides = make_tuple(conv_stride_h, conv_stride_w);
|
|
||||||
const auto conv_dilations = make_tuple(conv_dilation_h, conv_dilation_w);
|
|
||||||
const auto in_left_pads = make_tuple(in_left_pad_h, in_left_pad_w);
|
|
||||||
const auto in_right_pads = make_tuple(in_right_pad_h, in_right_pad_w);
|
|
||||||
|
|
||||||
#if USE_CONV_FWD_V4R4_NCHW
|
|
||||||
if(algo == ConvForwardAlgo::V4R4NCHW)
|
|
||||||
{
|
|
||||||
if(layout != ConvTensorLayout::NCHW)
|
|
||||||
{
|
|
||||||
throw std::runtime_error("wrong! layout");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
|
||||||
|
|
||||||
tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable =
|
|
||||||
&default_tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw;
|
|
||||||
|
|
||||||
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw<
|
|
||||||
in_data_t,
|
|
||||||
acc_data_t,
|
|
||||||
out_data_t>(handle,
|
|
||||||
tmp[I0],
|
|
||||||
tmp[I1],
|
|
||||||
tmp[I2],
|
|
||||||
conv_strides,
|
|
||||||
conv_dilations,
|
|
||||||
in_left_pads,
|
|
||||||
in_right_pads,
|
|
||||||
in,
|
|
||||||
wei,
|
|
||||||
out_device,
|
|
||||||
tunable,
|
|
||||||
nrepeat);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if USE_CONV_FWD_V6R1_NCHW
|
|
||||||
if(algo == ConvForwardAlgo::V6R1NCHW)
|
|
||||||
{
|
|
||||||
if(layout != ConvTensorLayout::NCHW)
|
|
||||||
{
|
|
||||||
throw std::runtime_error("wrong! layout");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
|
||||||
|
|
||||||
#if 1
|
|
||||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
|
|
||||||
get_datatype_enum_from_type<in_data_t>::value,
|
|
||||||
get_datatype_enum_from_type<acc_data_t>::value,
|
|
||||||
get_datatype_enum_from_type<out_data_t>::value,
|
|
||||||
256,
|
|
||||||
4,
|
|
||||||
1,
|
|
||||||
128,
|
|
||||||
32,
|
|
||||||
8,
|
|
||||||
4,
|
|
||||||
4,
|
|
||||||
1,
|
|
||||||
{8, 2},
|
|
||||||
{8, 2},
|
|
||||||
{4, 1, 1, 1, 1},
|
|
||||||
{2, 1, 1, 128, 1},
|
|
||||||
{4, 1, 1, 1, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
{1, 4, 1, 1, 1},
|
|
||||||
{8, 1, 1, 32, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
4,
|
|
||||||
true,
|
|
||||||
true};
|
|
||||||
#elif 0
|
|
||||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
|
|
||||||
get_datatype_enum_from_type<in_data_t>::value,
|
|
||||||
get_datatype_enum_from_type<acc_data_t>::value,
|
|
||||||
get_datatype_enum_from_type<out_data_t>::value,
|
|
||||||
256,
|
|
||||||
4,
|
|
||||||
2,
|
|
||||||
128,
|
|
||||||
32,
|
|
||||||
8,
|
|
||||||
4,
|
|
||||||
4,
|
|
||||||
1,
|
|
||||||
{8, 2},
|
|
||||||
{8, 2},
|
|
||||||
{4, 1, 1, 1, 2},
|
|
||||||
{2, 1, 1, 128, 1},
|
|
||||||
{4, 1, 1, 1, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
{1, 4, 1, 1, 2},
|
|
||||||
{8, 1, 1, 32, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
4,
|
|
||||||
true,
|
|
||||||
true};
|
|
||||||
#elif 1
|
|
||||||
const CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw compile_param = {
|
|
||||||
get_datatype_enum_from_type<in_data_t>::value,
|
|
||||||
get_datatype_enum_from_type<acc_data_t>::value,
|
|
||||||
get_datatype_enum_from_type<out_data_t>::value,
|
|
||||||
256,
|
|
||||||
4,
|
|
||||||
4,
|
|
||||||
128,
|
|
||||||
32,
|
|
||||||
8,
|
|
||||||
4,
|
|
||||||
4,
|
|
||||||
1,
|
|
||||||
{8, 2},
|
|
||||||
{8, 2},
|
|
||||||
{4, 1, 1, 1, 4},
|
|
||||||
{2, 1, 1, 128, 1},
|
|
||||||
{4, 1, 1, 1, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
{1, 4, 1, 1, 4},
|
|
||||||
{8, 1, 1, 32, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
{1, 1, 1, 1, 1},
|
|
||||||
4,
|
|
||||||
true,
|
|
||||||
true};
|
|
||||||
#endif
|
|
||||||
|
|
||||||
online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw<
|
|
||||||
in_data_t,
|
|
||||||
acc_data_t,
|
|
||||||
out_data_t>(handle,
|
|
||||||
tmp[I0],
|
|
||||||
tmp[I1],
|
|
||||||
tmp[I2],
|
|
||||||
conv_strides,
|
|
||||||
conv_dilations,
|
|
||||||
in_left_pads,
|
|
||||||
in_right_pads,
|
|
||||||
in,
|
|
||||||
wei,
|
|
||||||
out_device,
|
|
||||||
compile_param,
|
|
||||||
nrepeat);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if USE_CONV_FWD_V4R4_XDLOPS_NCHW
|
|
||||||
if(algo == ConvForwardAlgo::V4R4XDLNCHW)
|
|
||||||
{
|
|
||||||
if(layout != ConvTensorLayout::NCHW)
|
|
||||||
{
|
|
||||||
throw std::runtime_error("wrong! layout");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto tmp = f_make_for_device_nchw();
|
|
||||||
|
|
||||||
tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
|
|
||||||
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
|
||||||
|
|
||||||
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw<
|
|
||||||
in_data_t,
|
|
||||||
acc_data_t,
|
|
||||||
out_data_t>(handle,
|
|
||||||
tmp[I0],
|
|
||||||
tmp[I1],
|
|
||||||
tmp[I2],
|
|
||||||
conv_strides,
|
|
||||||
conv_dilations,
|
|
||||||
in_left_pads,
|
|
||||||
in_right_pads,
|
|
||||||
in,
|
|
||||||
wei,
|
|
||||||
out_device,
|
|
||||||
tunable,
|
|
||||||
nrepeat);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#if USE_CONV_FWD_V4R4_XDLOPS_NHWC
|
|
||||||
if(algo == ConvForwardAlgo::V4R4XDLNHWC)
|
|
||||||
{
|
|
||||||
if(layout != ConvTensorLayout::NHWC)
|
|
||||||
{
|
|
||||||
throw std::runtime_error("wrong! layout");
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto tmp = f_make_for_device_nhwc();
|
|
||||||
|
|
||||||
tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
|
|
||||||
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
|
|
||||||
|
|
||||||
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk<
|
|
||||||
in_data_t,
|
|
||||||
acc_data_t,
|
|
||||||
out_data_t>(handle,
|
|
||||||
tmp[I0],
|
|
||||||
tmp[I1],
|
|
||||||
tmp[I2],
|
|
||||||
conv_strides,
|
|
||||||
conv_dilations,
|
|
||||||
in_left_pads,
|
|
||||||
in_right_pads,
|
|
||||||
in,
|
|
||||||
wei,
|
|
||||||
out_device,
|
|
||||||
tunable,
|
|
||||||
nrepeat);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if(do_verification)
|
|
||||||
{
|
|
||||||
host_direct_convolution(in,
|
|
||||||
wei,
|
|
||||||
out_host,
|
|
||||||
make_tuple(conv_stride_h, conv_stride_w),
|
|
||||||
make_tuple(conv_dilation_h, conv_dilation_w),
|
|
||||||
make_tuple(in_left_pad_h, in_left_pad_w),
|
|
||||||
make_tuple(in_right_pad_h, in_right_pad_w),
|
|
||||||
layout);
|
|
||||||
|
|
||||||
check_error(out_host, out_device);
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
if(do_log)
|
|
||||||
{
|
|
||||||
LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
|
|
||||||
LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
|
||||||
LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
|
||||||
LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
delete handle;
|
|
||||||
MY_HIP_CHECK(hipStreamDestroy(stream));
|
|
||||||
}
|
|
||||||
@@ -1,395 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include "device.hpp"
|
|
||||||
#include "host_tensor.hpp"
|
|
||||||
#include "handle.hpp"
|
|
||||||
#include "online_driver_common.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
|
||||||
#include "conv_tunable_fwd_v4r4_dlops_nchw_kcyx_nkhw.hpp"
|
|
||||||
|
|
||||||
namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw {
|
|
||||||
|
|
||||||
template <typename TInWei, typename TAcc, typename TOut>
|
|
||||||
static std::string get_network_config_string_from_types()
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out += std::to_string(get_datatype_enum_from_type<TInWei>::value) + "_" +
|
|
||||||
std::to_string(get_datatype_enum_from_type<TAcc>::value) + "_" +
|
|
||||||
std::to_string(get_datatype_enum_from_type<TOut>::value);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::string
|
|
||||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt)
|
|
||||||
{
|
|
||||||
std::string out("TUN_");
|
|
||||||
|
|
||||||
out += std::to_string(pt->BlockSize) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
|
|
||||||
std::to_string(pt->KPerBlock) + "_";
|
|
||||||
out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" +
|
|
||||||
std::to_string(pt->KPerThread) + "_";
|
|
||||||
out += std::to_string(pt->M1N1ThreadClusterM10) + "x" +
|
|
||||||
std::to_string(pt->M1N1ThreadClusterN10) + "x" +
|
|
||||||
std::to_string(pt->M1N1ThreadClusterM11) + "x" +
|
|
||||||
std::to_string(pt->M1N1ThreadClusterN11) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
|
|
||||||
out += std::to_string(pt->ABlockTransferDstScalarPerVector_M1) + "_";
|
|
||||||
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
|
|
||||||
out += std::to_string(pt->BBlockTransferDstScalarPerVector_N1) + "_";
|
|
||||||
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TInWei, typename TAcc, typename TOut>
|
|
||||||
static std::string get_definition_string_from_types()
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TInWei>::value) +
|
|
||||||
" -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TAcc>::value) +
|
|
||||||
" -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TOut>::value);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::string
|
|
||||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* pt)
|
|
||||||
{
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
|
|
||||||
" -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
|
|
||||||
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
|
|
||||||
out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) +
|
|
||||||
" -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) +
|
|
||||||
" -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) +
|
|
||||||
" -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) +
|
|
||||||
" -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) +
|
|
||||||
" -DCK_PARAM_M1N1ThreadClusterN11=" + std::to_string(pt->M1N1ThreadClusterN11);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K_M0_M1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K_M0_M1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
|
|
||||||
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcScalarPerVector);
|
|
||||||
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_M1=" +
|
|
||||||
std::to_string(pt->ABlockTransferDstScalarPerVector_M1);
|
|
||||||
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
|
|
||||||
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K_N0_N1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K_N0_N1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
|
|
||||||
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcScalarPerVector);
|
|
||||||
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_N1=" +
|
|
||||||
std::to_string(pt->BBlockTransferDstScalarPerVector_N1);
|
|
||||||
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
|
|
||||||
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstVectorDim);
|
|
||||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
|
||||||
std::to_string(pt->CThreadTransferDstScalarPerVector);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw
|
|
||||||
|
|
||||||
template <typename TInWei,
|
|
||||||
typename TAcc,
|
|
||||||
typename TOut,
|
|
||||||
typename InLengths,
|
|
||||||
typename WeiLengths,
|
|
||||||
typename OutLengths,
|
|
||||||
typename ConvStrides,
|
|
||||||
typename ConvDilations,
|
|
||||||
typename InLeftPads,
|
|
||||||
typename InRightPads>
|
|
||||||
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
|
||||||
online_compile::Handle* handle,
|
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
|
||||||
const ConvStrides& conv_strides,
|
|
||||||
const ConvDilations& conv_dilations,
|
|
||||||
const InLeftPads& in_left_pads,
|
|
||||||
const InRightPads& in_right_pads,
|
|
||||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
|
||||||
const Tensor<TInWei>& wei_k_c_y_x,
|
|
||||||
Tensor<TOut>& out_n_k_ho_wo,
|
|
||||||
const tunable_dyn_conv_fwd_v4r4_dlops_nchw_kcyx_nkhw* tunable,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
using namespace ck_driver;
|
|
||||||
using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
|
|
||||||
using size_t = std::size_t;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// The follow codes are only used for computing the grid_size, hasMainKBlockLoop,
|
|
||||||
// hasDoubleTailKBlockLoop
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
|
||||||
const auto wei_k_c_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
||||||
|
|
||||||
const auto descs =
|
|
||||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
|
||||||
in_n_c_hi_wi_desc,
|
|
||||||
out_n_k_ho_wo_desc,
|
|
||||||
conv_strides,
|
|
||||||
conv_dilations,
|
|
||||||
in_left_pads,
|
|
||||||
in_right_pads);
|
|
||||||
const auto a_k_m_grid_desc = descs[I0];
|
|
||||||
const auto c_m_n_grid_desc = descs[I2];
|
|
||||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
|
||||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
|
||||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
|
||||||
|
|
||||||
const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
|
|
||||||
const bool hasMainKBlockLoop = ((K + tunable->KPerBlock) / (2 * tunable->KPerBlock) > 1);
|
|
||||||
const bool hasDoubleTailKBlockLoop = ((K / tunable->KPerBlock) % 2 == 0);
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// these buffers are usually provided by the user application
|
|
||||||
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
|
||||||
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
|
||||||
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
|
||||||
|
|
||||||
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
|
||||||
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
|
|
||||||
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
|
||||||
|
|
||||||
// these are workspace buffers that should be expressed to the user by the corresponding
|
|
||||||
// workspace API
|
|
||||||
DeviceMem workspace_buf(4096);
|
|
||||||
|
|
||||||
void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
|
|
||||||
void* b_k_n0_n1_grid_desc_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
|
||||||
void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
|
||||||
void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
|
|
||||||
|
|
||||||
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
|
||||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
|
||||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
|
|
||||||
|
|
||||||
std::string program_name =
|
|
||||||
"dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.cpp";
|
|
||||||
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_dlops_nchw";
|
|
||||||
|
|
||||||
std::string param = " -std=c++17 ";
|
|
||||||
std::string network_config;
|
|
||||||
|
|
||||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " " +
|
|
||||||
get_definition_string_from_tunable(tunable) +
|
|
||||||
" -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) +
|
|
||||||
" -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop);
|
|
||||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
|
|
||||||
get_network_config_string_from_tunable(tunable) + "_" +
|
|
||||||
std::to_string(hasMainKBlockLoop) + "_" +
|
|
||||||
std::to_string(hasDoubleTailKBlockLoop);
|
|
||||||
|
|
||||||
std::vector<float> kernel1_times;
|
|
||||||
std::vector<float> kernel2_times;
|
|
||||||
|
|
||||||
for(index_t i = 0; i < nrepeat; ++i)
|
|
||||||
{
|
|
||||||
KernelTimer timer1, timer2;
|
|
||||||
std::string kernel_name;
|
|
||||||
|
|
||||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare";
|
|
||||||
auto network_config_1 = network_config + "_1";
|
|
||||||
|
|
||||||
timer1.Start();
|
|
||||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
|
|
||||||
conv_strides[I0],
|
|
||||||
conv_strides[I1],
|
|
||||||
conv_dilations[I0],
|
|
||||||
conv_dilations[I1],
|
|
||||||
in_left_pads[I0],
|
|
||||||
in_left_pads[I1],
|
|
||||||
in_right_pads[I0],
|
|
||||||
in_right_pads[I1],
|
|
||||||
a_k_m0_m1_grid_desc_dev_buf,
|
|
||||||
b_k_n0_n1_grid_desc_dev_buf,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
|
|
||||||
timer1.End();
|
|
||||||
|
|
||||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw";
|
|
||||||
auto network_config_2 = network_config + "_2";
|
|
||||||
|
|
||||||
timer2.Start();
|
|
||||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
|
||||||
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
|
|
||||||
(const void*)(a_k_m0_m1_grid_desc_dev_buf),
|
|
||||||
(const void*)(b_k_n0_n1_grid_desc_dev_buf),
|
|
||||||
(const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf),
|
|
||||||
(const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
|
|
||||||
timer2.End();
|
|
||||||
|
|
||||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
|
||||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto ave_time1 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
|
|
||||||
(nrepeat - 1);
|
|
||||||
auto ave_time2 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
|
|
||||||
(nrepeat - 1);
|
|
||||||
|
|
||||||
const auto N = in_n_c_hi_wi_lengths[I0];
|
|
||||||
const auto C = in_n_c_hi_wi_lengths[I1];
|
|
||||||
|
|
||||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
|
||||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
|
||||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
|
||||||
|
|
||||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
|
||||||
const auto X = wei_k_c_y_x_lengths[I3];
|
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
|
||||||
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
|
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
|
||||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
|
||||||
};
|
|
||||||
|
|
||||||
// copy result back to host
|
|
||||||
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
|
||||||
}
|
|
||||||
@@ -1,386 +0,0 @@
|
|||||||
#include "device.hpp"
|
|
||||||
#include "host_tensor.hpp"
|
|
||||||
#include "handle.hpp"
|
|
||||||
#include "online_driver_common.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
|
||||||
|
|
||||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw {
|
|
||||||
|
|
||||||
template <typename TInWei, typename TAcc, typename TOut>
|
|
||||||
static std::string get_network_config_string_from_types()
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out += std::to_string(get_datatype_enum_from_type<TInWei>::value) + "_" +
|
|
||||||
std::to_string(get_datatype_enum_from_type<TAcc>::value) + "_" +
|
|
||||||
std::to_string(get_datatype_enum_from_type<TOut>::value);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::string
|
|
||||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
|
|
||||||
{
|
|
||||||
std::string out("TUN_");
|
|
||||||
|
|
||||||
out += std::to_string(pt->BlockSize) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
|
|
||||||
std::to_string(pt->KPerBlock) + "_";
|
|
||||||
out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
|
|
||||||
std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
|
|
||||||
std::to_string(pt->K1) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
|
|
||||||
out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
|
|
||||||
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
|
|
||||||
out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
|
|
||||||
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TInWei, typename TAcc, typename TOut>
|
|
||||||
static std::string get_definition_string_from_types()
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TInWei>::value) +
|
|
||||||
" -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TAcc>::value) +
|
|
||||||
" -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TOut>::value);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::string
|
|
||||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* pt)
|
|
||||||
{
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
|
|
||||||
" -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
|
|
||||||
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
|
|
||||||
out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
|
|
||||||
" -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
|
|
||||||
" -DCK_PARAM_K1=" + std::to_string(pt->K1) +
|
|
||||||
" -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
|
|
||||||
" -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
|
|
||||||
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcScalarPerVector);
|
|
||||||
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
|
|
||||||
std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
|
|
||||||
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
|
|
||||||
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
|
|
||||||
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcScalarPerVector);
|
|
||||||
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
|
|
||||||
std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
|
|
||||||
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
|
|
||||||
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstVectorDim);
|
|
||||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
|
||||||
std::to_string(pt->CThreadTransferDstScalarPerVector);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
|
|
||||||
|
|
||||||
template <typename TInWei,
|
|
||||||
typename TAcc,
|
|
||||||
typename TOut,
|
|
||||||
typename InLengths,
|
|
||||||
typename WeiLengths,
|
|
||||||
typename OutLengths,
|
|
||||||
typename ConvStrides,
|
|
||||||
typename ConvDilations,
|
|
||||||
typename InLeftPads,
|
|
||||||
typename InRightPads>
|
|
||||||
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
|
||||||
online_compile::Handle* handle,
|
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
|
||||||
const ConvStrides& conv_strides,
|
|
||||||
const ConvDilations& conv_dilations,
|
|
||||||
const InLeftPads& in_left_pads,
|
|
||||||
const InRightPads& in_right_pads,
|
|
||||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
|
||||||
const Tensor<TInWei>& wei_k_c_y_x,
|
|
||||||
Tensor<TOut>& out_n_k_ho_wo,
|
|
||||||
const tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
using namespace ck_driver;
|
|
||||||
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
|
||||||
using size_t = std::size_t;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
const auto in_n_c_hi_wi_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
|
|
||||||
const auto wei_k_c_y_x_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
|
|
||||||
const auto out_n_k_ho_wo_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
||||||
|
|
||||||
const auto n = in_n_c_hi_wi_desc.GetLength(I0);
|
|
||||||
const auto c = in_n_c_hi_wi_desc.GetLength(I1);
|
|
||||||
const auto hi = in_n_c_hi_wi_desc.GetLength(I2);
|
|
||||||
const auto wi = in_n_c_hi_wi_desc.GetLength(I3);
|
|
||||||
const auto k = wei_k_c_y_x_desc.GetLength(I0);
|
|
||||||
const auto y = wei_k_c_y_x_desc.GetLength(I2);
|
|
||||||
const auto x = wei_k_c_y_x_desc.GetLength(I3);
|
|
||||||
const auto ho = out_n_k_ho_wo_desc.GetLength(I2);
|
|
||||||
const auto wo = out_n_k_ho_wo_desc.GetLength(I3);
|
|
||||||
|
|
||||||
const auto M = k;
|
|
||||||
const auto N = n * ho * wo;
|
|
||||||
const auto K = c * y * x;
|
|
||||||
const auto K0 = K / tunable->K1;
|
|
||||||
|
|
||||||
const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
// these buffers are usually provided by the user application
|
|
||||||
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
|
||||||
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
|
||||||
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
|
||||||
|
|
||||||
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
|
||||||
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
|
|
||||||
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
|
||||||
|
|
||||||
// these are workspace buffers that should be expressed to the user by the corresponding
|
|
||||||
// workspace API
|
|
||||||
DeviceMem workspace_buf(4096);
|
|
||||||
|
|
||||||
void* a_k_m0_m1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
|
|
||||||
void* b_k_n0_n1_grid_desc_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
|
||||||
void* c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
|
||||||
void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
|
|
||||||
|
|
||||||
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
|
||||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
|
||||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
|
|
||||||
|
|
||||||
std::string program_name =
|
|
||||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp";
|
|
||||||
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nchw";
|
|
||||||
|
|
||||||
std::string param = " -std=c++17 ";
|
|
||||||
std::string network_config;
|
|
||||||
|
|
||||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " " + " -DCK_USE_AMD_XDLOPS" +
|
|
||||||
get_definition_string_from_tunable(tunable);
|
|
||||||
|
|
||||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
|
|
||||||
get_network_config_string_from_tunable(tunable);
|
|
||||||
|
|
||||||
std::vector<float> kernel1_times;
|
|
||||||
std::vector<float> kernel2_times;
|
|
||||||
|
|
||||||
for(index_t i = 0; i < nrepeat; ++i)
|
|
||||||
{
|
|
||||||
KernelTimer timer1, timer2;
|
|
||||||
std::string kernel_name;
|
|
||||||
|
|
||||||
kernel_name =
|
|
||||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare";
|
|
||||||
auto network_config_1 = network_config + "_1";
|
|
||||||
|
|
||||||
timer1.Start();
|
|
||||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
|
|
||||||
conv_strides[I0],
|
|
||||||
conv_strides[I1],
|
|
||||||
conv_dilations[I0],
|
|
||||||
conv_dilations[I1],
|
|
||||||
in_left_pads[I0],
|
|
||||||
in_left_pads[I1],
|
|
||||||
in_right_pads[I0],
|
|
||||||
in_right_pads[I1],
|
|
||||||
a_k_m0_m1_grid_desc_dev_buf,
|
|
||||||
b_k_n0_n1_grid_desc_dev_buf,
|
|
||||||
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
|
|
||||||
timer1.End();
|
|
||||||
|
|
||||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw";
|
|
||||||
auto network_config_2 = network_config + "_2";
|
|
||||||
|
|
||||||
timer2.Start();
|
|
||||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
|
||||||
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
|
|
||||||
(const void*)(a_k_m0_m1_grid_desc_dev_buf),
|
|
||||||
(const void*)(b_k_n0_n1_grid_desc_dev_buf),
|
|
||||||
(const void*)(c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf),
|
|
||||||
(const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
|
|
||||||
timer2.End();
|
|
||||||
|
|
||||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
|
||||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto ave_time1 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
|
|
||||||
(nrepeat - 1);
|
|
||||||
auto ave_time2 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
|
|
||||||
(nrepeat - 1);
|
|
||||||
|
|
||||||
const auto N = in_n_c_hi_wi_lengths[I0];
|
|
||||||
const auto C = in_n_c_hi_wi_lengths[I1];
|
|
||||||
|
|
||||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
|
||||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
|
||||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
|
||||||
|
|
||||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
|
||||||
const auto X = wei_k_c_y_x_lengths[I3];
|
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
|
||||||
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
|
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
|
||||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
|
||||||
};
|
|
||||||
|
|
||||||
// copy result back to host
|
|
||||||
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
|
||||||
}
|
|
||||||
@@ -1,389 +0,0 @@
|
|||||||
#include "device.hpp"
|
|
||||||
#include "host_tensor.hpp"
|
|
||||||
#include "handle.hpp"
|
|
||||||
#include "online_driver_common.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
|
||||||
#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
|
|
||||||
|
|
||||||
namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk {
|
|
||||||
|
|
||||||
template <typename TInWei, typename TAcc, typename TOut>
|
|
||||||
static std::string get_network_config_string_from_types()
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out += std::to_string(get_datatype_enum_from_type<TInWei>::value) + "_" +
|
|
||||||
std::to_string(get_datatype_enum_from_type<TAcc>::value) + "_" +
|
|
||||||
std::to_string(get_datatype_enum_from_type<TOut>::value);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::string
|
|
||||||
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
|
|
||||||
{
|
|
||||||
std::string out("TUN_");
|
|
||||||
|
|
||||||
out += std::to_string(pt->BlockSize) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->MPerBlock) + "x" + std::to_string(pt->NPerBlock) + "x" +
|
|
||||||
std::to_string(pt->KPerBlock) + "_";
|
|
||||||
out += std::to_string(pt->MPerWave) + "x" + std::to_string(pt->NPerWave) + "x" +
|
|
||||||
std::to_string(pt->MRepeat) + "x" + std::to_string(pt->NRepeat) + "x" +
|
|
||||||
std::to_string(pt->K1) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
|
|
||||||
out += std::to_string(pt->ABlockTransferDstScalarPerVector_K1) + "_";
|
|
||||||
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
|
|
||||||
out += std::to_string(pt->BBlockTransferDstScalarPerVector_K1) + "_";
|
|
||||||
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "x" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]) + "_";
|
|
||||||
|
|
||||||
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
|
|
||||||
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename TInWei, typename TAcc, typename TOut>
|
|
||||||
static std::string get_definition_string_from_types()
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_ABDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TInWei>::value) +
|
|
||||||
" -DCK_PARAM_AccDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TAcc>::value) +
|
|
||||||
" -DCK_PARAM_CDataTypeEnum=" + std::to_string(get_datatype_enum_from_type<TOut>::value);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
static std::string
|
|
||||||
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* pt)
|
|
||||||
{
|
|
||||||
std::string out;
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_MPerBlock=" + std::to_string(pt->MPerBlock) +
|
|
||||||
" -DCK_PARAM_NPerBlock=" + std::to_string(pt->NPerBlock) +
|
|
||||||
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
|
|
||||||
out += " -DCK_PARAM_MPerWave=" + std::to_string(pt->MPerWave) +
|
|
||||||
" -DCK_PARAM_NPerWave=" + std::to_string(pt->NPerWave) +
|
|
||||||
" -DCK_PARAM_K1=" + std::to_string(pt->K1) +
|
|
||||||
" -DCK_PARAM_MRepeat=" + std::to_string(pt->MRepeat) +
|
|
||||||
" -DCK_PARAM_NRepeat=" + std::to_string(pt->NRepeat);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadSliceLengths_K0_M_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterLengths_K0_M_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->ABlockTransferSrcAccessOrder[2]);
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
|
|
||||||
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
|
|
||||||
std::to_string(pt->ABlockTransferSrcScalarPerVector);
|
|
||||||
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_K1=" +
|
|
||||||
std::to_string(pt->ABlockTransferDstScalarPerVector_K1);
|
|
||||||
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
|
|
||||||
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadSliceLengths_K0_N_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterLengths_K0_N_K1[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->BBlockTransferSrcAccessOrder[2]);
|
|
||||||
|
|
||||||
out +=
|
|
||||||
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
|
|
||||||
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
|
|
||||||
std::to_string(pt->BBlockTransferSrcScalarPerVector);
|
|
||||||
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_K1=" +
|
|
||||||
std::to_string(pt->BBlockTransferDstScalarPerVector_K1);
|
|
||||||
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
|
|
||||||
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[6]) + "," +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstAccessOrder[7]);
|
|
||||||
|
|
||||||
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
|
|
||||||
std::to_string(pt->CThreadTransferSrcDstVectorDim);
|
|
||||||
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
|
|
||||||
std::to_string(pt->CThreadTransferDstScalarPerVector);
|
|
||||||
|
|
||||||
return (out);
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
|
|
||||||
|
|
||||||
template <typename TInWei,
|
|
||||||
typename TAcc,
|
|
||||||
typename TOut,
|
|
||||||
typename InLengths,
|
|
||||||
typename WeiLengths,
|
|
||||||
typename OutLengths,
|
|
||||||
typename ConvStrides,
|
|
||||||
typename ConvDilations,
|
|
||||||
typename InLeftPads,
|
|
||||||
typename InRightPads>
|
|
||||||
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
|
||||||
online_compile::Handle* handle,
|
|
||||||
const InLengths& in_n_hi_wi_c_lengths,
|
|
||||||
const WeiLengths& wei_k_y_x_c_lengths,
|
|
||||||
const OutLengths& out_n_ho_wo_k_lengths,
|
|
||||||
const ConvStrides& conv_strides,
|
|
||||||
const ConvDilations& conv_dilations,
|
|
||||||
const InLeftPads& in_left_pads,
|
|
||||||
const InRightPads& in_right_pads,
|
|
||||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
|
||||||
const Tensor<TInWei>& wei_k_y_x_c,
|
|
||||||
Tensor<TOut>& out_n_ho_wo_k,
|
|
||||||
const tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
|
|
||||||
using size_t = std::size_t;
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
// The follow codes are only used for computing the grid_size, hasMainKBlockLoop,
|
|
||||||
// hasDoubleTailKBlockLoop
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
const auto in_n_hi_wi_c_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
|
|
||||||
const auto wei_k_y_x_c_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
|
|
||||||
const auto out_n_ho_wo_k_desc =
|
|
||||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
|
||||||
|
|
||||||
const auto n = in_n_hi_wi_c_desc.GetLength(I0);
|
|
||||||
const auto hi = in_n_hi_wi_c_desc.GetLength(I1);
|
|
||||||
const auto wi = in_n_hi_wi_c_desc.GetLength(I2);
|
|
||||||
const auto c = in_n_hi_wi_c_desc.GetLength(I3);
|
|
||||||
|
|
||||||
const auto k = wei_k_y_x_c_desc.GetLength(I0);
|
|
||||||
const auto y = wei_k_y_x_c_desc.GetLength(I1);
|
|
||||||
const auto x = wei_k_y_x_c_desc.GetLength(I2);
|
|
||||||
|
|
||||||
const auto ho = out_n_ho_wo_k_desc.GetLength(I1);
|
|
||||||
const auto wo = out_n_ho_wo_k_desc.GetLength(I2);
|
|
||||||
|
|
||||||
const auto M = k;
|
|
||||||
const auto N = n * ho * wo;
|
|
||||||
const auto K = c * y * x;
|
|
||||||
const auto K0 = K / tunable->K1;
|
|
||||||
|
|
||||||
const index_t grid_size = (M / tunable->MPerBlock) * (N / tunable->NPerBlock);
|
|
||||||
|
|
||||||
// these buffers are usually provided by the user application
|
|
||||||
DeviceMem in_n_hi_wi_c_dev_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
|
||||||
DeviceMem wei_k_y_x_c_dev_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
|
||||||
DeviceMem out_n_ho_wo_k_dev_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
|
||||||
|
|
||||||
in_n_hi_wi_c_dev_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
|
||||||
wei_k_y_x_c_dev_buf.ToDevice(wei_k_y_x_c.mData.data());
|
|
||||||
out_n_ho_wo_k_dev_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
|
||||||
|
|
||||||
// these are workspace buffers that should be expressed to the user by the corresponding
|
|
||||||
// workspace API
|
|
||||||
DeviceMem workspace_buf(4096);
|
|
||||||
|
|
||||||
void* a_k0_m_k1_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
|
|
||||||
void* b_k0_n_k1_grid_desc_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
|
|
||||||
void* c_m0_m1_m2_n_grid_desc_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
|
|
||||||
void* c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf =
|
|
||||||
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
|
|
||||||
|
|
||||||
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
|
||||||
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
|
|
||||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
|
|
||||||
|
|
||||||
std::string program_name =
|
|
||||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp";
|
|
||||||
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_xdlops_nhwc";
|
|
||||||
|
|
||||||
std::string param = " -std=c++17 ";
|
|
||||||
std::string network_config;
|
|
||||||
|
|
||||||
param += get_definition_string_from_types<TInWei, TAcc, TOut>() + " -DCK_USE_AMD_XDLOPS ";
|
|
||||||
param += get_definition_string_from_tunable(tunable);
|
|
||||||
|
|
||||||
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
|
|
||||||
get_network_config_string_from_tunable(tunable);
|
|
||||||
|
|
||||||
std::vector<float> kernel1_times;
|
|
||||||
std::vector<float> kernel2_times;
|
|
||||||
|
|
||||||
for(index_t i = 0; i < nrepeat; ++i)
|
|
||||||
{
|
|
||||||
KernelTimer timer1, timer2;
|
|
||||||
std::string kernel_name;
|
|
||||||
|
|
||||||
kernel_name =
|
|
||||||
"dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare";
|
|
||||||
auto network_config_1 = network_config + "_1";
|
|
||||||
|
|
||||||
timer1.Start();
|
|
||||||
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
|
|
||||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I0]),
|
|
||||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I1]),
|
|
||||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I2]),
|
|
||||||
static_cast<index_t>(in_n_hi_wi_c_lengths[I3]),
|
|
||||||
static_cast<index_t>(wei_k_y_x_c_lengths[I0]),
|
|
||||||
static_cast<index_t>(wei_k_y_x_c_lengths[I1]),
|
|
||||||
static_cast<index_t>(wei_k_y_x_c_lengths[I2]),
|
|
||||||
conv_strides[I0],
|
|
||||||
conv_strides[I1],
|
|
||||||
conv_dilations[I0],
|
|
||||||
conv_dilations[I1],
|
|
||||||
in_left_pads[I0],
|
|
||||||
in_left_pads[I1],
|
|
||||||
in_right_pads[I0],
|
|
||||||
in_right_pads[I1],
|
|
||||||
a_k0_m_k1_grid_desc_dev_buf,
|
|
||||||
b_k0_n_k1_grid_desc_dev_buf,
|
|
||||||
c_m0_m1_m2_n_grid_desc_dev_buf,
|
|
||||||
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf);
|
|
||||||
timer1.End();
|
|
||||||
|
|
||||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk";
|
|
||||||
auto network_config_2 = network_config + "_2";
|
|
||||||
|
|
||||||
timer2.Start();
|
|
||||||
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
|
|
||||||
reinterpret_cast<const TInWei*>(in_n_hi_wi_c_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<const TInWei*>(wei_k_y_x_c_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<TOut*>(out_n_ho_wo_k_dev_buf.GetDeviceBuffer()),
|
|
||||||
(const void*)(a_k0_m_k1_grid_desc_dev_buf),
|
|
||||||
(const void*)(b_k0_n_k1_grid_desc_dev_buf),
|
|
||||||
(const void*)(c_m0_m1_m2_n_grid_desc_dev_buf),
|
|
||||||
(const void*)(c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf));
|
|
||||||
timer2.End();
|
|
||||||
|
|
||||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
|
||||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto ave_time1 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
|
|
||||||
(nrepeat - 1);
|
|
||||||
auto ave_time2 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
|
|
||||||
(nrepeat - 1);
|
|
||||||
|
|
||||||
const auto N = in_n_hi_wi_c_lengths[I0];
|
|
||||||
const auto C = in_n_hi_wi_c_lengths[I3];
|
|
||||||
|
|
||||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
|
||||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
|
||||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
|
||||||
|
|
||||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
|
||||||
const auto X = wei_k_y_x_c_lengths[I2];
|
|
||||||
|
|
||||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
|
||||||
(std::size_t(1000) * 1000 * 1000) / ave_time2;
|
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
|
||||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
|
||||||
};
|
|
||||||
|
|
||||||
// copy result back to host
|
|
||||||
out_n_ho_wo_k_dev_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
|
||||||
}
|
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include "device.hpp"
|
|
||||||
#include "host_tensor.hpp"
|
|
||||||
#include "handle.hpp"
|
|
||||||
#include "online_driver_common.hpp"
|
|
||||||
#include "convolution_problem_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
|
||||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
|
||||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
|
||||||
#include "conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp"
|
|
||||||
|
|
||||||
template <typename TInWei,
|
|
||||||
typename TAcc,
|
|
||||||
typename TOut,
|
|
||||||
typename InLengths,
|
|
||||||
typename WeiLengths,
|
|
||||||
typename OutLengths,
|
|
||||||
typename ConvStrides,
|
|
||||||
typename ConvDilations,
|
|
||||||
typename InLeftPads,
|
|
||||||
typename InRightPads>
|
|
||||||
void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
|
||||||
online_compile::Handle* handle,
|
|
||||||
const InLengths& in_n_c_hi_wi_lengths,
|
|
||||||
const WeiLengths& wei_k_c_y_x_lengths,
|
|
||||||
const OutLengths& out_n_k_ho_wo_lengths,
|
|
||||||
const ConvStrides& conv_strides,
|
|
||||||
const ConvDilations& conv_dilations,
|
|
||||||
const InLeftPads& in_left_pads,
|
|
||||||
const InRightPads& in_right_pads,
|
|
||||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
|
||||||
const Tensor<TInWei>& wei_k_c_y_x,
|
|
||||||
Tensor<TOut>& out_n_k_ho_wo,
|
|
||||||
const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param,
|
|
||||||
ck::index_t nrepeat)
|
|
||||||
{
|
|
||||||
using namespace ck;
|
|
||||||
using namespace ck_driver;
|
|
||||||
using size_t = std::size_t;
|
|
||||||
|
|
||||||
std::cout << __func__ << std::endl;
|
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
|
||||||
constexpr auto I1 = Number<1>{};
|
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
ConvolutionProblemDescriptor conv_problem_desc{in_n_c_hi_wi_lengths[I0],
|
|
||||||
out_n_k_ho_wo_lengths[I1],
|
|
||||||
in_n_c_hi_wi_lengths[I1],
|
|
||||||
wei_k_c_y_x_lengths[I2],
|
|
||||||
wei_k_c_y_x_lengths[I3],
|
|
||||||
in_n_c_hi_wi_lengths[I2],
|
|
||||||
in_n_c_hi_wi_lengths[I3],
|
|
||||||
out_n_k_ho_wo_lengths[I2],
|
|
||||||
out_n_k_ho_wo_lengths[I3],
|
|
||||||
conv_strides[I0],
|
|
||||||
conv_strides[I1],
|
|
||||||
conv_dilations[I0],
|
|
||||||
conv_dilations[I1],
|
|
||||||
in_left_pads[I0],
|
|
||||||
in_left_pads[I1],
|
|
||||||
in_right_pads[I0],
|
|
||||||
in_right_pads[I1],
|
|
||||||
get_datatype_enum_from_type<TInWei>::value,
|
|
||||||
get_datatype_enum_from_type<TInWei>::value,
|
|
||||||
get_datatype_enum_from_type<TOut>::value};
|
|
||||||
|
|
||||||
if(!ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::IsValidCompileParameter(conv_problem_desc,
|
|
||||||
compile_param))
|
|
||||||
{
|
|
||||||
throw std::runtime_error("wrong! IsValidCompileParameter fail");
|
|
||||||
}
|
|
||||||
|
|
||||||
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
|
||||||
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
|
||||||
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
|
||||||
|
|
||||||
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
|
||||||
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
|
|
||||||
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
|
||||||
|
|
||||||
// workspace is used for save transformed tensor descritpors created by prepare kernel
|
|
||||||
DeviceMem workspace_dev_buf(
|
|
||||||
ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetWorkSpaceSize(conv_problem_desc, compile_param));
|
|
||||||
|
|
||||||
const auto block_size = std::size_t(
|
|
||||||
ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetBlockSize(conv_problem_desc, compile_param));
|
|
||||||
|
|
||||||
const auto grid_size = std::size_t(
|
|
||||||
ConvIgemmFwdV6r1DlopsNchwKcyxNkhw::GetGridSize(conv_problem_desc, compile_param));
|
|
||||||
|
|
||||||
const std::vector<size_t> vld1 = {1, 1, 1};
|
|
||||||
const std::vector<size_t> vgd1 = {1, 1, 1};
|
|
||||||
|
|
||||||
const std::vector<size_t> vld2 = {static_cast<size_t>(block_size), 1, 1};
|
|
||||||
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * block_size), 1, 1};
|
|
||||||
|
|
||||||
std::string program_name =
|
|
||||||
"dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp";
|
|
||||||
std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw";
|
|
||||||
|
|
||||||
std::string compile_param_string = get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString();
|
|
||||||
std::string network_config = compile_param_string;
|
|
||||||
|
|
||||||
std::vector<float> kernel1_times;
|
|
||||||
std::vector<float> kernel2_times;
|
|
||||||
|
|
||||||
for(index_t i = 0; i < nrepeat + 1; ++i)
|
|
||||||
{
|
|
||||||
KernelTimer timer1, timer2;
|
|
||||||
std::string kernel_name;
|
|
||||||
|
|
||||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare";
|
|
||||||
auto network_config_1 = network_config + "_1";
|
|
||||||
|
|
||||||
timer1.Start();
|
|
||||||
handle->AddKernel(algo_name,
|
|
||||||
network_config_1,
|
|
||||||
program_name,
|
|
||||||
kernel_name,
|
|
||||||
vld1,
|
|
||||||
vgd1,
|
|
||||||
compile_param_string)(static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
|
|
||||||
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
|
|
||||||
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
|
|
||||||
conv_strides[I0],
|
|
||||||
conv_strides[I1],
|
|
||||||
conv_dilations[I0],
|
|
||||||
conv_dilations[I1],
|
|
||||||
in_left_pads[I0],
|
|
||||||
in_left_pads[I1],
|
|
||||||
in_right_pads[I0],
|
|
||||||
in_right_pads[I1],
|
|
||||||
(void*)(workspace_dev_buf.GetDeviceBuffer()));
|
|
||||||
timer1.End();
|
|
||||||
|
|
||||||
kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw";
|
|
||||||
auto network_config_2 = network_config + "_2";
|
|
||||||
|
|
||||||
timer2.Start();
|
|
||||||
handle->AddKernel(algo_name,
|
|
||||||
network_config_2,
|
|
||||||
program_name,
|
|
||||||
kernel_name,
|
|
||||||
vld2,
|
|
||||||
vgd2,
|
|
||||||
compile_param_string)(
|
|
||||||
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
|
|
||||||
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
|
|
||||||
(const void*)(workspace_dev_buf.GetDeviceBuffer()));
|
|
||||||
timer2.End();
|
|
||||||
|
|
||||||
kernel1_times.push_back(timer1.GetElapsedTime());
|
|
||||||
kernel2_times.push_back(timer2.GetElapsedTime());
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
auto ave_time1 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel1_times.begin()), kernel1_times.end(), 0., std::plus<float>{}) /
|
|
||||||
nrepeat;
|
|
||||||
auto ave_time2 =
|
|
||||||
std::accumulate(
|
|
||||||
std::next(kernel2_times.begin()), kernel2_times.end(), 0., std::plus<float>{}) /
|
|
||||||
nrepeat;
|
|
||||||
|
|
||||||
float perf = (float)(conv_problem_desc.CalculateFlop()) /
|
|
||||||
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
|
|
||||||
|
|
||||||
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
|
|
||||||
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
|
|
||||||
};
|
|
||||||
|
|
||||||
// copy result back to host
|
|
||||||
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
|
||||||
}
|
|
||||||
@@ -10,6 +10,8 @@ set(HOST_TENSOR_SOURCE
|
|||||||
## the library target
|
## the library target
|
||||||
add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE})
|
add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE})
|
||||||
|
|
||||||
|
target_include_directories(host_tensor SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
|
||||||
|
|
||||||
target_link_libraries(host_tensor PRIVATE hip::device)
|
target_link_libraries(host_tensor PRIVATE hip::device)
|
||||||
target_link_libraries(host_tensor INTERFACE hip::host)
|
target_link_libraries(host_tensor INTERFACE hip::host)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#ifndef CONV_COMMON_HPP
|
#ifndef CONV_COMMON_HPP
|
||||||
#define CONV_COMMON_HPP
|
#define CONV_COMMON_HPP
|
||||||
|
|
||||||
#include "dynamic_tensor_descriptor.hpp"
|
#include "tensor_descriptor.hpp"
|
||||||
|
|
||||||
enum ConvTensorLayout
|
enum ConvTensorLayout
|
||||||
{
|
{
|
||||||
@@ -19,8 +19,8 @@ template <typename... InDesc,
|
|||||||
typename LeftPads,
|
typename LeftPads,
|
||||||
typename RightPads>
|
typename RightPads>
|
||||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||||
const ck::DynamicTensorDescriptor<InDesc...>& in_desc,
|
const ck::TensorDescriptor<InDesc...>& in_desc,
|
||||||
const ck::DynamicTensorDescriptor<WeiDesc...>& wei_desc,
|
const ck::TensorDescriptor<WeiDesc...>& wei_desc,
|
||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations conv_dilations,
|
const ConvDilations conv_dilations,
|
||||||
const LeftPads& left_pads,
|
const LeftPads& left_pads,
|
||||||
@@ -57,12 +57,12 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
|||||||
const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
|
const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
|
||||||
const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;
|
const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;
|
||||||
|
|
||||||
return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
|
return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class InDesc, class WeiDesc, class OutDesc>
|
template <class InDesc, class WeiDesc, class OutDesc>
|
||||||
constexpr std::size_t
|
constexpr std::size_t
|
||||||
calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc)
|
calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc)
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
|
|
||||||
|
|||||||
@@ -34,24 +34,16 @@ struct KernelTimer
|
|||||||
using device_stream_t = hipStream_t;
|
using device_stream_t = hipStream_t;
|
||||||
|
|
||||||
template <typename... Args, typename F>
|
template <typename... Args, typename F>
|
||||||
void launch_kernel(F kernel,
|
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||||
dim3 grid_dim,
|
|
||||||
dim3 block_dim,
|
|
||||||
std::size_t lds_byte,
|
|
||||||
hipStream_t stream_id,
|
|
||||||
Args... args)
|
|
||||||
{
|
{
|
||||||
|
hipStream_t stream_id = nullptr;
|
||||||
|
|
||||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... Args, typename F>
|
template <typename... Args, typename F>
|
||||||
float launch_and_time_kernel(F kernel,
|
float launch_and_time_kernel(
|
||||||
int nrepeat,
|
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||||
dim3 grid_dim,
|
|
||||||
dim3 block_dim,
|
|
||||||
std::size_t lds_byte,
|
|
||||||
hipStream_t stream_id,
|
|
||||||
Args... args)
|
|
||||||
{
|
{
|
||||||
KernelTimer timer;
|
KernelTimer timer;
|
||||||
|
|
||||||
@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel,
|
|||||||
|
|
||||||
printf("Warm up\n");
|
printf("Warm up\n");
|
||||||
|
|
||||||
|
hipStream_t stream_id = nullptr;
|
||||||
|
|
||||||
// warm up
|
// warm up
|
||||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||||
|
|
||||||
|
|||||||
@@ -14,15 +14,13 @@ void host_direct_convolution(const Tensor<TIn>& in,
|
|||||||
const ConvStrides& conv_strides,
|
const ConvStrides& conv_strides,
|
||||||
const ConvDilations& conv_dilations,
|
const ConvDilations& conv_dilations,
|
||||||
const InLeftPads& in_left_pads,
|
const InLeftPads& in_left_pads,
|
||||||
const InRightPads& in_right_pads,
|
const InRightPads&,
|
||||||
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
const ConvTensorLayout layout = ConvTensorLayout::NCHW)
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
|
|
||||||
constexpr auto I0 = Number<0>{};
|
constexpr auto I0 = Number<0>{};
|
||||||
constexpr auto I1 = Number<1>{};
|
constexpr auto I1 = Number<1>{};
|
||||||
constexpr auto I2 = Number<2>{};
|
|
||||||
constexpr auto I3 = Number<3>{};
|
|
||||||
|
|
||||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||||
double v = 0;
|
double v = 0;
|
||||||
@@ -68,23 +66,25 @@ void host_direct_convolution(const Tensor<TIn>& in,
|
|||||||
out(n, ho, wo, k) = v;
|
out(n, ho, wo, k) = v;
|
||||||
};
|
};
|
||||||
|
|
||||||
switch(layout)
|
if(layout == ConvTensorLayout::NCHW)
|
||||||
{
|
{
|
||||||
case ConvTensorLayout::NCHW:
|
|
||||||
make_ParallelTensorFunctor(f_nchw,
|
make_ParallelTensorFunctor(f_nchw,
|
||||||
out.mDesc.GetLengths()[0],
|
out.mDesc.GetLengths()[0],
|
||||||
out.mDesc.GetLengths()[1],
|
out.mDesc.GetLengths()[1],
|
||||||
out.mDesc.GetLengths()[2],
|
out.mDesc.GetLengths()[2],
|
||||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||||
break;
|
}
|
||||||
case ConvTensorLayout::NHWC:
|
else if(layout == ConvTensorLayout::NHWC)
|
||||||
|
{
|
||||||
make_ParallelTensorFunctor(f_nhwc,
|
make_ParallelTensorFunctor(f_nhwc,
|
||||||
out.mDesc.GetLengths()[0],
|
out.mDesc.GetLengths()[0],
|
||||||
out.mDesc.GetLengths()[1],
|
out.mDesc.GetLengths()[1],
|
||||||
out.mDesc.GetLengths()[2],
|
out.mDesc.GetLengths()[2],
|
||||||
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
||||||
break;
|
}
|
||||||
default: throw std::runtime_error("wrong! not supported layout");
|
else
|
||||||
|
{
|
||||||
|
throw std::runtime_error("wrong! not supported layout");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,17 +100,15 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
|||||||
constexpr std::size_t HoPerTile = 2;
|
constexpr std::size_t HoPerTile = 2;
|
||||||
constexpr std::size_t WoPerTile = 2;
|
constexpr std::size_t WoPerTile = 2;
|
||||||
|
|
||||||
std::size_t N = in_nchw.mDesc.GetLengths()[0];
|
std::size_t N = in_nchw.mDesc.GetLengths()[0];
|
||||||
std::size_t C = in_nchw.mDesc.GetLengths()[1];
|
std::size_t C = in_nchw.mDesc.GetLengths()[1];
|
||||||
std::size_t HI = in_nchw.mDesc.GetLengths()[2];
|
|
||||||
std::size_t WI = in_nchw.mDesc.GetLengths()[3];
|
|
||||||
|
|
||||||
std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
|
std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
|
||||||
std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
|
std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
|
||||||
std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
|
std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
|
||||||
|
|
||||||
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
|
std::size_t Ho = out_nkhw.mDesc.GetLengths()[2];
|
||||||
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
|
std::size_t Wo = out_nkhw.mDesc.GetLengths()[3];
|
||||||
|
|
||||||
index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
|
index_t h_pad_low = InLeftPads{}.Get(Number<0>{});
|
||||||
index_t w_pad_low = InLeftPads{}.Get(Number<1>{});
|
index_t w_pad_low = InLeftPads{}.Get(Number<1>{});
|
||||||
@@ -118,8 +116,8 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
|||||||
std::size_t HiPerTile = HoPerTile + Y - 1;
|
std::size_t HiPerTile = HoPerTile + Y - 1;
|
||||||
std::size_t WiPerTile = WoPerTile + X - 1;
|
std::size_t WiPerTile = WoPerTile + X - 1;
|
||||||
|
|
||||||
std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile;
|
std::size_t HTile = (Ho + HoPerTile - 1) / HoPerTile;
|
||||||
std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile;
|
std::size_t WTile = (Wo + WoPerTile - 1) / WoPerTile;
|
||||||
|
|
||||||
Tensor<double> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
Tensor<double> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||||
Tensor<double> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
Tensor<double> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user