mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Composable kernel init integration v3 (#1097)
* Squashed 'src/composable_kernel/' content from commitf6edda611git-subtree-dir: src/composable_kernel git-subtree-split:f6edda6119* add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files * Squashed 'src/composable_kernel/' changes from f6edda611..5781adf5c5781adf5cUpdate develop (#5) (#6)97e6d514fMerge pull request #4 from ROCmSoftwarePlatform/separate_online_compile7b1ec41e5refactor49c33aaearefactor54b3e73d1rename git-subtree-dir: src/composable_kernel git-subtree-split:5781adf5cf* fix * refactor * remove online compilation from CK * refactor * fix * add ctest * add c-style pointer cast * vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast * fix clang warning suppression * tidy * suppress cppcheck * fix enum issue * revert chagnes to hip build * fix kernel filename * update CK build script * rename * rename * make innner product compatiable on gfx900 * Update src/include/miopen/solver/ck_utility_common.hpp Co-authored-by: JD <Jehandad.Khan@amd.com> * compiler parameter use stream * use int instead of index_t in kernel wrapper * DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element * refactor * refactor * change cmakelist * change ck common utility * fix Co-authored-by: JD <Jehandad.Khan@amd.com> [ROCm/composable_kernel commit:6fe3627a9e]
This commit is contained in:
90
.clang-format
Normal file
90
.clang-format
Normal file
@@ -0,0 +1,90 @@
|
||||
---
|
||||
Language: Cpp
|
||||
AccessModifierOffset: 0
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignConsecutiveAssignments: true
|
||||
AlignConsecutiveDeclarations: false
|
||||
AlignEscapedNewlinesLeft: true
|
||||
AlignOperands: true
|
||||
AlignTrailingComments: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: true
|
||||
AllowShortCaseLabelsOnASingleLine: true
|
||||
AllowShortFunctionsOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: false
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
AlwaysBreakAfterDefinitionReturnType: None
|
||||
AlwaysBreakAfterReturnType: None
|
||||
AlwaysBreakBeforeMultilineStrings: false
|
||||
AlwaysBreakTemplateDeclarations: true
|
||||
BinPackArguments: false
|
||||
BinPackParameters: false
|
||||
BraceWrapping:
|
||||
AfterClass: true
|
||||
AfterControlStatement: true
|
||||
AfterEnum: true
|
||||
AfterFunction: true
|
||||
AfterNamespace: false
|
||||
AfterObjCDeclaration: true
|
||||
AfterStruct: true
|
||||
AfterUnion: true
|
||||
BeforeCatch: true
|
||||
BeforeElse: true
|
||||
IndentBraces: false
|
||||
BreakBeforeBinaryOperators: None
|
||||
BreakBeforeBraces: Custom
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializersBeforeComma: false
|
||||
ColumnLimit: 100
|
||||
CommentPragmas: '^ IWYU pragma:'
|
||||
ConstructorInitializerAllOnOneLineOrOnePerLine: true
|
||||
ConstructorInitializerIndentWidth: 4
|
||||
ContinuationIndentWidth: 4
|
||||
Cpp11BracedListStyle: true
|
||||
DerivePointerAlignment: false
|
||||
DisableFormat: false
|
||||
ExperimentalAutoDetectBinPacking: false
|
||||
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
|
||||
IncludeCategories:
|
||||
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
|
||||
Priority: 2
|
||||
- Regex: '^(<|"(gtest|isl|json)/)'
|
||||
Priority: 3
|
||||
- Regex: '.*'
|
||||
Priority: 1
|
||||
IndentCaseLabels: false
|
||||
IndentWidth: 4
|
||||
IndentWrappedFunctionNames: false
|
||||
KeepEmptyLinesAtTheStartOfBlocks: true
|
||||
MacroBlockBegin: ''
|
||||
MacroBlockEnd: ''
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
ObjCBlockIndentWidth: 2
|
||||
ObjCSpaceAfterProperty: false
|
||||
ObjCSpaceBeforeProtocolList: true
|
||||
PenaltyBreakBeforeFirstCallParameter: 19
|
||||
PenaltyBreakComment: 300
|
||||
PenaltyBreakFirstLessLess: 120
|
||||
PenaltyBreakString: 1000
|
||||
PenaltyExcessCharacter: 1000000
|
||||
PenaltyReturnTypeOnItsOwnLine: 60
|
||||
PointerAlignment: Left
|
||||
ReflowComments: true
|
||||
SortIncludes: false
|
||||
SpaceAfterCStyleCast: false
|
||||
# SpaceAfterTemplateKeyword: true
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeParens: Never
|
||||
SpaceInEmptyParentheses: false
|
||||
SpacesBeforeTrailingComments: 1
|
||||
SpacesInAngles: false
|
||||
SpacesInContainerLiterals: true
|
||||
SpacesInCStyleCastParentheses: false
|
||||
SpacesInParentheses: false
|
||||
SpacesInSquareBrackets: false
|
||||
Standard: Cpp11
|
||||
TabWidth: 8
|
||||
UseTab: Never
|
||||
...
|
||||
|
||||
3
.clang-tidy
Normal file
3
.clang-tidy
Normal file
@@ -0,0 +1,3 @@
|
||||
CheckOptions:
|
||||
- key: bugprone-reserved-identifier.AllowedIdentifiers
|
||||
value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__'
|
||||
198
CMakeLists.txt
Normal file
198
CMakeLists.txt
Normal file
@@ -0,0 +1,198 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
project(composable_kernel)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
|
||||
include(CheckCXXCompilerFlag)
|
||||
|
||||
## C++
|
||||
enable_language(CXX)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
|
||||
|
||||
## OpenMP
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
|
||||
# workaround issue hipcc in rocm3.5 cannot find openmp
|
||||
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
|
||||
set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
|
||||
set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
|
||||
set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
|
||||
set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
|
||||
set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
|
||||
else()
|
||||
find_package(OpenMP REQUIRED)
|
||||
endif()
|
||||
|
||||
message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
|
||||
message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
|
||||
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
|
||||
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
|
||||
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
link_libraries(${OpenMP_gomp_LIBRARY})
|
||||
link_libraries(${OpenMP_pthread_LIBRARY})
|
||||
|
||||
## HIP
|
||||
find_package(HIP REQUIRED)
|
||||
message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
|
||||
## half
|
||||
#find_path(HALF_INCLUDE_DIR half.hpp)
|
||||
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
|
||||
|
||||
# CMAKE_CXX_FLAGS
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
## tidy
|
||||
include(EnableCompilerWarnings)
|
||||
set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
|
||||
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
|
||||
set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
|
||||
# Enable tidy on hip
|
||||
elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
|
||||
set(MIOPEN_TIDY_ERRORS ALL)
|
||||
endif()
|
||||
|
||||
include(ClangTidy)
|
||||
enable_clang_tidy(
|
||||
CHECKS
|
||||
*
|
||||
-abseil-*
|
||||
-android-cloexec-fopen
|
||||
# Yea we shouldn't be using rand()
|
||||
-cert-msc30-c
|
||||
-bugprone-exception-escape
|
||||
-bugprone-macro-parentheses
|
||||
-cert-env33-c
|
||||
-cert-msc32-c
|
||||
-cert-msc50-cpp
|
||||
-cert-msc51-cpp
|
||||
-cert-dcl37-c
|
||||
-cert-dcl51-cpp
|
||||
-clang-analyzer-alpha.core.CastToStruct
|
||||
-clang-analyzer-optin.performance.Padding
|
||||
-clang-diagnostic-deprecated-declarations
|
||||
-clang-diagnostic-extern-c-compat
|
||||
-clang-diagnostic-unused-command-line-argument
|
||||
-cppcoreguidelines-avoid-c-arrays
|
||||
-cppcoreguidelines-avoid-magic-numbers
|
||||
-cppcoreguidelines-explicit-virtual-functions
|
||||
-cppcoreguidelines-init-variables
|
||||
-cppcoreguidelines-macro-usage
|
||||
-cppcoreguidelines-non-private-member-variables-in-classes
|
||||
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
|
||||
-cppcoreguidelines-pro-bounds-constant-array-index
|
||||
-cppcoreguidelines-pro-bounds-pointer-arithmetic
|
||||
-cppcoreguidelines-pro-type-member-init
|
||||
-cppcoreguidelines-pro-type-reinterpret-cast
|
||||
-cppcoreguidelines-pro-type-union-access
|
||||
-cppcoreguidelines-pro-type-vararg
|
||||
-cppcoreguidelines-special-member-functions
|
||||
-fuchsia-*
|
||||
-google-explicit-constructor
|
||||
-google-readability-braces-around-statements
|
||||
-google-readability-todo
|
||||
-google-runtime-int
|
||||
-google-runtime-references
|
||||
-hicpp-vararg
|
||||
-hicpp-braces-around-statements
|
||||
-hicpp-explicit-conversions
|
||||
-hicpp-named-parameter
|
||||
-hicpp-no-array-decay
|
||||
# We really shouldn't use bitwise operators with signed integers, but
|
||||
# opencl leaves us no choice
|
||||
-hicpp-avoid-c-arrays
|
||||
-hicpp-signed-bitwise
|
||||
-hicpp-special-member-functions
|
||||
-hicpp-uppercase-literal-suffix
|
||||
-hicpp-use-auto
|
||||
-hicpp-use-equals-default
|
||||
-hicpp-use-override
|
||||
-llvm-header-guard
|
||||
-llvm-include-order
|
||||
#-llvmlibc-*
|
||||
-llvmlibc-restrict-system-libc-headers
|
||||
-llvmlibc-callee-namespace
|
||||
-llvmlibc-implementation-in-namespace
|
||||
-llvm-else-after-return
|
||||
-llvm-qualified-auto
|
||||
-misc-misplaced-const
|
||||
-misc-non-private-member-variables-in-classes
|
||||
-misc-no-recursion
|
||||
-modernize-avoid-bind
|
||||
-modernize-avoid-c-arrays
|
||||
-modernize-pass-by-value
|
||||
-modernize-use-auto
|
||||
-modernize-use-default-member-init
|
||||
-modernize-use-equals-default
|
||||
-modernize-use-trailing-return-type
|
||||
-modernize-use-transparent-functors
|
||||
-performance-unnecessary-value-param
|
||||
-readability-braces-around-statements
|
||||
-readability-else-after-return
|
||||
# we are not ready to use it, but very useful
|
||||
-readability-function-cognitive-complexity
|
||||
-readability-isolate-declaration
|
||||
-readability-magic-numbers
|
||||
-readability-named-parameter
|
||||
-readability-uppercase-literal-suffix
|
||||
-readability-convert-member-functions-to-static
|
||||
-readability-qualified-auto
|
||||
-readability-redundant-string-init
|
||||
# too many narrowing conversions in our code
|
||||
-bugprone-narrowing-conversions
|
||||
-cppcoreguidelines-narrowing-conversions
|
||||
-altera-struct-pack-align
|
||||
-cppcoreguidelines-prefer-member-initializer
|
||||
|
||||
${MIOPEN_TIDY_CHECKS}
|
||||
${MIOPEN_TIDY_ERRORS}
|
||||
HEADER_FILTER
|
||||
"\.hpp$"
|
||||
EXTRA_ARGS
|
||||
-DMIOPEN_USE_CLANG_TIDY
|
||||
)
|
||||
|
||||
include(CppCheck)
|
||||
enable_cppcheck(
|
||||
CHECKS
|
||||
warning
|
||||
style
|
||||
performance
|
||||
portability
|
||||
SUPPRESS
|
||||
ConfigurationNotChecked
|
||||
constStatement
|
||||
duplicateCondition
|
||||
noExplicitConstructor
|
||||
passedByValue
|
||||
preprocessorErrorDirective
|
||||
shadowVariable
|
||||
unusedFunction
|
||||
unusedPrivateFunction
|
||||
unusedStructMember
|
||||
unmatchedSuppression
|
||||
FORCE
|
||||
SOURCES
|
||||
host/host_tensor/src
|
||||
host/driver_offline/src
|
||||
composable_kernel/src/kernel_wrapper
|
||||
INCLUDE
|
||||
host/host_tensor/include
|
||||
host/solver/include
|
||||
host/driver_offline/include
|
||||
composable_kernel/include/*
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/include
|
||||
${CMAKE_CURRENT_BINARY_DIR}/include
|
||||
DEFINE
|
||||
CPPCHECK=1
|
||||
__linux__=1
|
||||
)
|
||||
|
||||
add_subdirectory(host)
|
||||
177
README.md
Normal file
177
README.md
Normal file
@@ -0,0 +1,177 @@
|
||||
# How to build and run
|
||||
|
||||
# Docker
|
||||
```
|
||||
docker run \
|
||||
-it \
|
||||
--rm \
|
||||
--privileged \
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/tensorflow:rocm4.2-tf2.4-dev \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
# Install Boost for online compilation
|
||||
https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install
|
||||
|
||||
|
||||
# Build
|
||||
Add path of Boost
|
||||
```
|
||||
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
|
||||
```
|
||||
|
||||
```
|
||||
mkdir build && cd build
|
||||
```
|
||||
|
||||
cmake cmd. Need to Specify target ID, example below is gfx908
|
||||
```
|
||||
cmake \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
|
||||
-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
..
|
||||
```
|
||||
|
||||
Build drivers: \
|
||||
``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \
|
||||
``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \
|
||||
``conv_fwd_driver_online`` is (online compilation) driver for forward convolution
|
||||
```
|
||||
make -j conv_fwd_driver_offline
|
||||
make -j conv_bwd_driver_offline
|
||||
make -j conv_fwd_driver_online
|
||||
```
|
||||
|
||||
# Run
|
||||
* layout: 0 = NCHW; 1 = NHWC
|
||||
* algo: algorithm
|
||||
* verify: 0 = no verification; 1 = do verification
|
||||
* init: 0 ~ 5. initialization method
|
||||
* log: 0 = no log; 1 = do log
|
||||
* repeat: number of time kernel being launched
|
||||
```
|
||||
######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
```
|
||||
|
||||
# Result
|
||||
Forward convoltuion, FP16, NCHW
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
|
||||
layout: 0
|
||||
in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1}
|
||||
wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1}
|
||||
out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {2, 2, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
||||
a_k0_m_k1_grid_desc{216, 256, 8}
|
||||
b_k0_n_k1_grid_desc{216, 165888, 8}
|
||||
c_m_n_grid_desc{ 256, 165888}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 1.4155 ms, 103.686 TFlop/s
|
||||
```
|
||||
|
||||
Forward convoltuion, FP16, NCHW
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
layout: 0
|
||||
in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1}
|
||||
wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1}
|
||||
out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {1, 1, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
|
||||
a_k0_m_k1_grid_desc{288, 1024, 8}
|
||||
b_k0_n_k1_grid_desc{288, 50176, 8}
|
||||
c_m_n_grid_desc{ 1024, 50176}
|
||||
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 2.21357 ms, 106.959 TFlop/s
|
||||
```
|
||||
|
||||
Forward convolution, FP16, NHWC
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
|
||||
|
||||
layout: 1
|
||||
in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1}
|
||||
wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1}
|
||||
out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {2, 2, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
||||
a_k0_m_k1_grid_desc{216, 165888, 8}
|
||||
b_k0_n_k1_grid_desc{216, 256, 8}
|
||||
c_m_n_grid_desc{ 165888, 256}
|
||||
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 1.12014 ms, 131.025 TFlop/s
|
||||
```
|
||||
|
||||
Forward convolution, FP16, NHWC
|
||||
```
|
||||
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
layout: 1
|
||||
in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
|
||||
wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1}
|
||||
out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {1, 1, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
|
||||
a_k0_m_k1_grid_desc{288, 50176, 8}
|
||||
b_k0_n_k1_grid_desc{288, 1024, 8}
|
||||
c_m_n_grid_desc{ 50176, 1024}
|
||||
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 1.86877 ms, 126.693 TFlop/s
|
||||
```
|
||||
|
||||
Backward data convolution, FP16, NHWC
|
||||
```
|
||||
./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
|
||||
|
||||
layout: 1
|
||||
in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
|
||||
wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1}
|
||||
out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
|
||||
InLeftPads size 2, {1, 1, }
|
||||
InRightPads size 2, {1, 1, }
|
||||
ConvStrides size 2, {1, 1, }
|
||||
ConvDilations size 2, {1, 1, }
|
||||
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
|
||||
a_k0_m_k1_grid_desc{288, 50176, 8}
|
||||
b_k0_n_k1_grid_desc{288, 1024, 8}
|
||||
c_m_n_grid_desc{ 50176, 1024}
|
||||
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
|
||||
Warm up
|
||||
Start running 1 times...
|
||||
Average time : 2.22461 ms, 106.428 TFlop/s
|
||||
```
|
||||
34
cmake/Analyzers.cmake
Normal file
34
cmake/Analyzers.cmake
Normal file
@@ -0,0 +1,34 @@
|
||||
################################################################################
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
if(NOT TARGET analyze)
|
||||
add_custom_target(analyze)
|
||||
endif()
|
||||
|
||||
function(mark_as_analyzer)
|
||||
add_dependencies(analyze ${ARGN})
|
||||
endfunction()
|
||||
|
||||
162
cmake/ClangTidy.cmake
Normal file
162
cmake/ClangTidy.cmake
Normal file
@@ -0,0 +1,162 @@
|
||||
################################################################################
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
################################################################################
|
||||
include(CMakeParseArguments)
|
||||
include(Analyzers)
|
||||
|
||||
get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH)
|
||||
|
||||
find_program(CLANG_TIDY_EXE
|
||||
NAMES
|
||||
clang-tidy
|
||||
clang-tidy-5.0
|
||||
clang-tidy-4.0
|
||||
clang-tidy-3.9
|
||||
clang-tidy-3.8
|
||||
clang-tidy-3.7
|
||||
clang-tidy-3.6
|
||||
clang-tidy-3.5
|
||||
HINTS
|
||||
${CLANG_TIDY_EXE_HINT}
|
||||
PATH_SUFFIXES
|
||||
compiler/bin
|
||||
PATHS
|
||||
/opt/rocm/llvm/bin
|
||||
/opt/rocm/hcc
|
||||
/usr/local/opt/llvm/bin
|
||||
)
|
||||
|
||||
function(find_clang_tidy_version VAR)
|
||||
execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT)
|
||||
separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}")
|
||||
list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX)
|
||||
if(VERSION_INDEX GREATER 0)
|
||||
math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1")
|
||||
list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION)
|
||||
set(${VAR} ${VERSION} PARENT_SCOPE)
|
||||
else()
|
||||
set(${VAR} "0.0" PARENT_SCOPE)
|
||||
endif()
|
||||
|
||||
endfunction()
|
||||
|
||||
if( NOT CLANG_TIDY_EXE )
|
||||
message( STATUS "Clang tidy not found" )
|
||||
set(CLANG_TIDY_VERSION "0.0")
|
||||
else()
|
||||
find_clang_tidy_version(CLANG_TIDY_VERSION)
|
||||
message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}")
|
||||
endif()
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits)
|
||||
file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR})
|
||||
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR})
|
||||
|
||||
macro(enable_clang_tidy)
|
||||
set(options ANALYZE_TEMPORARY_DTORS ALL)
|
||||
set(oneValueArgs HEADER_FILTER)
|
||||
set(multiValueArgs CHECKS ERRORS EXTRA_ARGS)
|
||||
|
||||
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}")
|
||||
string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}")
|
||||
set(CLANG_TIDY_EXTRA_ARGS)
|
||||
foreach(ARG ${PARSE_EXTRA_ARGS})
|
||||
list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}")
|
||||
endforeach()
|
||||
|
||||
set(CLANG_TIDY_ALL)
|
||||
if(PARSE_ALL)
|
||||
set(CLANG_TIDY_ALL ALL)
|
||||
endif()
|
||||
|
||||
message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}")
|
||||
|
||||
if (${PARSE_ANALYZE_TEMPORARY_DTORS})
|
||||
set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors")
|
||||
endif()
|
||||
|
||||
if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
|
||||
set(CLANG_TIDY_ERRORS_ARG "")
|
||||
else()
|
||||
set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'")
|
||||
endif()
|
||||
|
||||
if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
|
||||
set(CLANG_TIDY_QUIET_ARG "")
|
||||
else()
|
||||
set(CLANG_TIDY_QUIET_ARG "-quiet")
|
||||
endif()
|
||||
|
||||
if(PARSE_HEADER_FILTER)
|
||||
string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}")
|
||||
else()
|
||||
set(CLANG_TIDY_HEADER_FILTER ".*")
|
||||
endif()
|
||||
|
||||
set(CLANG_TIDY_COMMAND
|
||||
${CLANG_TIDY_EXE}
|
||||
${CLANG_TIDY_QUIET_ARG}
|
||||
-p ${CMAKE_BINARY_DIR}
|
||||
-checks='${CLANG_TIDY_CHECKS}'
|
||||
${CLANG_TIDY_ERRORS_ARG}
|
||||
${CLANG_TIDY_EXTRA_ARGS}
|
||||
${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS}
|
||||
-header-filter='${CLANG_TIDY_HEADER_FILTER}'
|
||||
)
|
||||
add_custom_target(tidy ${CLANG_TIDY_ALL})
|
||||
mark_as_analyzer(tidy)
|
||||
add_custom_target(tidy-base)
|
||||
add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR})
|
||||
add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR})
|
||||
add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir)
|
||||
add_dependencies(tidy-base tidy-make-fixit-dir)
|
||||
endmacro()
|
||||
|
||||
function(clang_tidy_check TARGET)
|
||||
get_target_property(SOURCES ${TARGET} SOURCES)
|
||||
# TODO: Use generator expressions instead
|
||||
# COMMAND ${CLANG_TIDY_COMMAND} $<TARGET_PROPERTY:${TARGET},SOURCES>
|
||||
# COMMAND ${CLANG_TIDY_COMMAND} $<JOIN:$<TARGET_PROPERTY:${TARGET},SOURCES>, >
|
||||
foreach(SOURCE ${SOURCES})
|
||||
if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS"))
|
||||
string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file)
|
||||
set(tidy_target tidy-target-${TARGET}-${tidy_file})
|
||||
add_custom_target(${tidy_target}
|
||||
# for some targets clang-tidy not able to get information from .clang-tidy
|
||||
DEPENDS ${SOURCE}
|
||||
COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
|
||||
)
|
||||
add_dependencies(${tidy_target} ${TARGET})
|
||||
add_dependencies(${tidy_target} tidy-base)
|
||||
add_dependencies(tidy ${tidy_target})
|
||||
endif()
|
||||
endforeach()
|
||||
endfunction()
|
||||
|
||||
130
cmake/CppCheck.cmake
Normal file
130
cmake/CppCheck.cmake
Normal file
@@ -0,0 +1,130 @@
|
||||
################################################################################
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
include(CMakeParseArguments)
|
||||
include(ProcessorCount)
|
||||
include(Analyzers)
|
||||
|
||||
find_program(CPPCHECK_EXE
|
||||
NAMES
|
||||
cppcheck
|
||||
PATHS
|
||||
/opt/rocm/bin
|
||||
)
|
||||
|
||||
ProcessorCount(CPPCHECK_JOBS)
|
||||
|
||||
set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build)
|
||||
file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR})
|
||||
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR})
|
||||
|
||||
macro(enable_cppcheck)
|
||||
set(options FORCE)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES)
|
||||
|
||||
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}")
|
||||
string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*")
|
||||
file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}")
|
||||
set(CPPCHECK_DEFINES)
|
||||
foreach(DEF ${PARSE_DEFINE})
|
||||
set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}")
|
||||
endforeach()
|
||||
|
||||
set(CPPCHECK_UNDEFINES)
|
||||
foreach(DEF ${PARSE_UNDEFINE})
|
||||
set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}")
|
||||
endforeach()
|
||||
|
||||
set(CPPCHECK_INCLUDES)
|
||||
foreach(INC ${PARSE_INCLUDE})
|
||||
set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}")
|
||||
endforeach()
|
||||
|
||||
# set(CPPCHECK_FORCE)
|
||||
set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json")
|
||||
if(PARSE_FORCE)
|
||||
set(CPPCHECK_FORCE --force)
|
||||
endif()
|
||||
|
||||
set(SOURCES)
|
||||
set(GLOBS)
|
||||
foreach(SOURCE ${PARSE_SOURCES})
|
||||
get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE)
|
||||
if(EXISTS ${ABS_SOURCE})
|
||||
if(IS_DIRECTORY ${ABS_SOURCE})
|
||||
set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h")
|
||||
else()
|
||||
set(SOURCES "${SOURCES} ${ABS_SOURCE}")
|
||||
endif()
|
||||
else()
|
||||
set(GLOBS "${GLOBS} ${ABS_SOURCE}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake "
|
||||
file(GLOB_RECURSE GSRCS ${GLOBS})
|
||||
set(CPPCHECK_COMMAND
|
||||
${CPPCHECK_EXE}
|
||||
-q
|
||||
# -v
|
||||
# --report-progress
|
||||
${CPPCHECK_FORCE}
|
||||
--cppcheck-build-dir=${CPPCHECK_BUILD_DIR}
|
||||
--platform=native
|
||||
--template=gcc
|
||||
--error-exitcode=1
|
||||
-j ${CPPCHECK_JOBS}
|
||||
${CPPCHECK_DEFINES}
|
||||
${CPPCHECK_UNDEFINES}
|
||||
${CPPCHECK_INCLUDES}
|
||||
--enable=${CPPCHECK_CHECKS}
|
||||
--inline-suppr
|
||||
--suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions
|
||||
${SOURCES} \${GSRCS}
|
||||
)
|
||||
string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\")
|
||||
message(\"\${CPPCHECK_SHOW_COMMAND}\")
|
||||
execute_process(
|
||||
COMMAND \${CPPCHECK_COMMAND}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
RESULT_VARIABLE RESULT
|
||||
)
|
||||
if(NOT RESULT EQUAL 0)
|
||||
message(FATAL_ERROR \"Cppcheck failed\")
|
||||
endif()
|
||||
")
|
||||
|
||||
add_custom_target(cppcheck
|
||||
COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
COMMENT "cppcheck: Running cppcheck..."
|
||||
)
|
||||
mark_as_analyzer(cppcheck)
|
||||
endmacro()
|
||||
|
||||
|
||||
355
cmake/DoxygenDoc.cmake
Normal file
355
cmake/DoxygenDoc.cmake
Normal file
@@ -0,0 +1,355 @@
|
||||
################################################################################
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
################################################################################
|
||||
include(CMakeParseArguments)
|
||||
include(MainDoc)
|
||||
|
||||
find_program(DOXYGEN_EXECUTABLE NAMES doxygen
|
||||
PATH_SUFFIXES bin
|
||||
DOC "Doxygen documentation generator"
|
||||
)
|
||||
mark_as_advanced(DOXYGEN_EXECUTABLE)
|
||||
|
||||
find_path(DOT_EXECUTABLE NAMES dot
|
||||
PATH_SUFFIXES bin
|
||||
DOC "Graphviz"
|
||||
)
|
||||
mark_as_advanced(DOT_EXECUTABLE)
|
||||
|
||||
set(DOXYGEN_ARGS
|
||||
ABBREVIATE_BRIEF
|
||||
ALIASES
|
||||
ALLEXTERNALS
|
||||
ALLOW_UNICODE_NAMES
|
||||
ALPHABETICAL_INDEX
|
||||
ALWAYS_DETAILED_SEC
|
||||
AUTOLINK_SUPPORT
|
||||
BINARY_TOC
|
||||
BRIEF_MEMBER_DESC
|
||||
BUILTIN_STL_SUPPORT
|
||||
CALLER_GRAPH
|
||||
CALL_GRAPH
|
||||
CASE_SENSE_NAMES
|
||||
CHM_FILE
|
||||
CHM_INDEX_ENCODING
|
||||
CITE_BIB_FILES
|
||||
CLANG_ASSISTED_PARSING
|
||||
CLANG_OPTIONS
|
||||
CLASS_DIAGRAMS
|
||||
CLASS_GRAPH
|
||||
COLLABORATION_GRAPH
|
||||
COLS_IN_ALPHA_INDEX
|
||||
COMPACT_LATEX
|
||||
COMPACT_RTF
|
||||
CPP_CLI_SUPPORT
|
||||
CREATE_SUBDIRS
|
||||
DIAFILE_DIRS
|
||||
DIA_PATH
|
||||
DIRECTORY_GRAPH
|
||||
DISABLE_INDEX
|
||||
DISTRIBUTE_GROUP_DOC
|
||||
DOCBOOK_OUTPUT
|
||||
DOCBOOK_PROGRAMLISTING
|
||||
DOCSET_BUNDLE_ID
|
||||
DOCSET_FEEDNAME
|
||||
DOCSET_PUBLISHER_ID
|
||||
DOCSET_PUBLISHER_NAME
|
||||
DOTFILE_DIRS
|
||||
DOT_CLEANUP
|
||||
DOT_FONTNAME
|
||||
DOT_FONTPATH
|
||||
DOT_FONTSIZE
|
||||
DOT_GRAPH_MAX_NODES
|
||||
DOT_IMAGE_FORMAT
|
||||
DOT_MULTI_TARGETS
|
||||
DOT_NUM_THREADS
|
||||
# DOT_PATH
|
||||
DOT_TRANSPARENT
|
||||
DOXYFILE_ENCODING
|
||||
ECLIPSE_DOC_ID
|
||||
ENABLED_SECTIONS
|
||||
ENABLE_PREPROCESSING
|
||||
ENUM_VALUES_PER_LINE
|
||||
EXAMPLE_PATH
|
||||
EXAMPLE_PATTERNS
|
||||
EXAMPLE_RECURSIVE
|
||||
EXCLUDE
|
||||
EXCLUDE_PATTERNS
|
||||
EXCLUDE_SYMBOLS
|
||||
EXCLUDE_SYMLINKS
|
||||
EXPAND_AS_DEFINED
|
||||
EXPAND_ONLY_PREDEF
|
||||
EXTENSION_MAPPING
|
||||
EXTERNAL_GROUPS
|
||||
EXTERNAL_PAGES
|
||||
EXTERNAL_SEARCH
|
||||
EXTERNAL_SEARCH_ID
|
||||
EXTRACT_ALL
|
||||
EXTRACT_ANON_NSPACES
|
||||
EXTRACT_LOCAL_CLASSES
|
||||
EXTRACT_LOCAL_METHODS
|
||||
EXTRACT_PACKAGE
|
||||
EXTRACT_PRIVATE
|
||||
EXTRACT_STATIC
|
||||
EXTRA_PACKAGES
|
||||
EXTRA_SEARCH_MAPPINGS
|
||||
EXT_LINKS_IN_WINDOW
|
||||
FILE_PATTERNS
|
||||
FILE_VERSION_FILTER
|
||||
FILTER_PATTERNS
|
||||
FILTER_SOURCE_FILES
|
||||
FILTER_SOURCE_PATTERNS
|
||||
FORCE_LOCAL_INCLUDES
|
||||
FORMULA_FONTSIZE
|
||||
FORMULA_TRANSPARENT
|
||||
FULL_PATH_NAMES
|
||||
GENERATE_AUTOGEN_DEF
|
||||
GENERATE_BUGLIST
|
||||
GENERATE_CHI
|
||||
GENERATE_DEPRECATEDLIST
|
||||
GENERATE_DOCBOOK
|
||||
GENERATE_DOCSET
|
||||
GENERATE_ECLIPSEHELP
|
||||
GENERATE_HTML
|
||||
GENERATE_HTMLHELP
|
||||
GENERATE_LATEX
|
||||
GENERATE_LEGEND
|
||||
GENERATE_MAN
|
||||
GENERATE_PERLMOD
|
||||
GENERATE_QHP
|
||||
GENERATE_RTF
|
||||
GENERATE_TAGFILE
|
||||
GENERATE_TESTLIST
|
||||
GENERATE_TODOLIST
|
||||
GENERATE_TREEVIEW
|
||||
GENERATE_XML
|
||||
GRAPHICAL_HIERARCHY
|
||||
GROUP_GRAPHS
|
||||
GROUP_NESTED_COMPOUNDS
|
||||
# HAVE_DOT
|
||||
HHC_LOCATION
|
||||
HIDE_COMPOUND_REFERENCE
|
||||
HIDE_FRIEND_COMPOUNDS
|
||||
HIDE_IN_BODY_DOCS
|
||||
HIDE_SCOPE_NAMES
|
||||
HIDE_UNDOC_CLASSES
|
||||
HIDE_UNDOC_MEMBERS
|
||||
HIDE_UNDOC_RELATIONS
|
||||
HTML_COLORSTYLE_GAMMA
|
||||
HTML_COLORSTYLE_HUE
|
||||
HTML_COLORSTYLE_SAT
|
||||
HTML_DYNAMIC_SECTIONS
|
||||
HTML_EXTRA_FILES
|
||||
HTML_EXTRA_STYLESHEET
|
||||
HTML_FILE_EXTENSION
|
||||
HTML_FOOTER
|
||||
HTML_HEADER
|
||||
HTML_INDEX_NUM_ENTRIES
|
||||
HTML_OUTPUT
|
||||
HTML_STYLESHEET
|
||||
HTML_TIMESTAMP
|
||||
IDL_PROPERTY_SUPPORT
|
||||
IGNORE_PREFIX
|
||||
IMAGE_PATH
|
||||
INCLUDED_BY_GRAPH
|
||||
INCLUDE_FILE_PATTERNS
|
||||
INCLUDE_GRAPH
|
||||
INCLUDE_PATH
|
||||
INHERIT_DOCS
|
||||
INLINE_GROUPED_CLASSES
|
||||
INLINE_INFO
|
||||
INLINE_INHERITED_MEMB
|
||||
INLINE_SIMPLE_STRUCTS
|
||||
INLINE_SOURCES
|
||||
INPUT
|
||||
INPUT_ENCODING
|
||||
INPUT_FILTER
|
||||
INTERACTIVE_SVG
|
||||
INTERNAL_DOCS
|
||||
JAVADOC_AUTOBRIEF
|
||||
LATEX_BATCHMODE
|
||||
LATEX_BIB_STYLE
|
||||
LATEX_CMD_NAME
|
||||
LATEX_EXTRA_FILES
|
||||
LATEX_EXTRA_STYLESHEET
|
||||
LATEX_FOOTER
|
||||
LATEX_HEADER
|
||||
LATEX_HIDE_INDICES
|
||||
LATEX_OUTPUT
|
||||
LATEX_SOURCE_CODE
|
||||
LATEX_TIMESTAMP
|
||||
LAYOUT_FILE
|
||||
LOOKUP_CACHE_SIZE
|
||||
MACRO_EXPANSION
|
||||
MAKEINDEX_CMD_NAME
|
||||
MAN_EXTENSION
|
||||
MAN_LINKS
|
||||
MAN_OUTPUT
|
||||
MAN_SUBDIR
|
||||
MARKDOWN_SUPPORT
|
||||
MATHJAX_CODEFILE
|
||||
MATHJAX_EXTENSIONS
|
||||
MATHJAX_FORMAT
|
||||
MATHJAX_RELPATH
|
||||
MAX_DOT_GRAPH_DEPTH
|
||||
MAX_INITIALIZER_LINES
|
||||
MSCFILE_DIRS
|
||||
MSCGEN_PATH
|
||||
MULTILINE_CPP_IS_BRIEF
|
||||
OPTIMIZE_FOR_FORTRAN
|
||||
OPTIMIZE_OUTPUT_FOR_C
|
||||
OPTIMIZE_OUTPUT_JAVA
|
||||
OPTIMIZE_OUTPUT_VHDL
|
||||
OUTPUT_DIRECTORY
|
||||
OUTPUT_LANGUAGE
|
||||
PAPER_TYPE
|
||||
PDF_HYPERLINKS
|
||||
PERLMOD_LATEX
|
||||
PERLMOD_MAKEVAR_PREFIX
|
||||
PERLMOD_PRETTY
|
||||
PERL_PATH
|
||||
PLANTUML_CFG_FILE
|
||||
PLANTUML_INCLUDE_PATH
|
||||
PLANTUML_JAR_PATH
|
||||
PREDEFINED
|
||||
PROJECT_BRIEF
|
||||
PROJECT_LOGO
|
||||
PROJECT_NAME
|
||||
PROJECT_NUMBER
|
||||
QCH_FILE
|
||||
QHG_LOCATION
|
||||
QHP_CUST_FILTER_ATTRS
|
||||
QHP_CUST_FILTER_NAME
|
||||
QHP_NAMESPACE
|
||||
QHP_SECT_FILTER_ATTRS
|
||||
QHP_VIRTUAL_FOLDER
|
||||
QT_AUTOBRIEF
|
||||
QUIET
|
||||
RECURSIVE
|
||||
REFERENCED_BY_RELATION
|
||||
REFERENCES_LINK_SOURCE
|
||||
REFERENCES_RELATION
|
||||
REPEAT_BRIEF
|
||||
RTF_EXTENSIONS_FILE
|
||||
RTF_HYPERLINKS
|
||||
RTF_OUTPUT
|
||||
RTF_SOURCE_CODE
|
||||
RTF_STYLESHEET_FILE
|
||||
SEARCHDATA_FILE
|
||||
SEARCHENGINE
|
||||
SEARCHENGINE_URL
|
||||
SEARCH_INCLUDES
|
||||
SEPARATE_MEMBER_PAGES
|
||||
SERVER_BASED_SEARCH
|
||||
SHORT_NAMES
|
||||
SHOW_FILES
|
||||
SHOW_GROUPED_MEMB_INC
|
||||
SHOW_INCLUDE_FILES
|
||||
SHOW_NAMESPACES
|
||||
SHOW_USED_FILES
|
||||
SIP_SUPPORT
|
||||
SKIP_FUNCTION_MACROS
|
||||
SORT_BRIEF_DOCS
|
||||
SORT_BY_SCOPE_NAME
|
||||
SORT_GROUP_NAMES
|
||||
SORT_MEMBERS_CTORS_1ST
|
||||
SORT_MEMBER_DOCS
|
||||
SOURCE_BROWSER
|
||||
SOURCE_TOOLTIPS
|
||||
STRICT_PROTO_MATCHING
|
||||
STRIP_CODE_COMMENTS
|
||||
STRIP_FROM_INC_PATH
|
||||
STRIP_FROM_PATH
|
||||
SUBGROUPING
|
||||
TAB_SIZE
|
||||
TAGFILES
|
||||
TCL_SUBST
|
||||
TEMPLATE_RELATIONS
|
||||
TOC_EXPAND
|
||||
TOC_INCLUDE_HEADINGS
|
||||
TREEVIEW_WIDTH
|
||||
TYPEDEF_HIDES_STRUCT
|
||||
UML_LIMIT_NUM_FIELDS
|
||||
UML_LOOK
|
||||
USE_HTAGS
|
||||
USE_MATHJAX
|
||||
USE_MDFILE_AS_MAINPAGE
|
||||
USE_PDFLATEX
|
||||
VERBATIM_HEADERS
|
||||
WARNINGS
|
||||
WARN_AS_ERROR
|
||||
WARN_FORMAT
|
||||
WARN_IF_DOC_ERROR
|
||||
WARN_IF_UNDOCUMENTED
|
||||
WARN_LOGFILE
|
||||
WARN_NO_PARAMDOC
|
||||
XML_OUTPUT
|
||||
XML_PROGRAMLISTING
|
||||
)
|
||||
|
||||
set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file")
|
||||
|
||||
function(add_doxygen_doc)
|
||||
set(options)
|
||||
set(oneValueArgs)
|
||||
set(multiValueArgs DEPENDS ${DOXYGEN_ARGS})
|
||||
|
||||
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
|
||||
|
||||
file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n")
|
||||
|
||||
foreach(ARG ${DOXYGEN_ARGS})
|
||||
if(PARSE_${ARG})
|
||||
string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}})
|
||||
file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(PARSE_OUTPUT_DIRECTORY)
|
||||
if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY})
|
||||
file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(DOT_EXECUTABLE)
|
||||
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n")
|
||||
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n")
|
||||
else()
|
||||
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n")
|
||||
endif()
|
||||
|
||||
add_custom_target(doxygen
|
||||
${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
COMMENT "Building documentation with doxygen"
|
||||
)
|
||||
if(PARSE_OUTPUT_DIRECTORY)
|
||||
clean_doc_output(${PARSE_OUTPUT_DIRECTORY})
|
||||
endif()
|
||||
mark_as_doc(doxygen)
|
||||
if(PARSE_DEPENDS)
|
||||
add_dependencies(doxygen ${PARSE_DEPENDS})
|
||||
endif()
|
||||
endfunction()
|
||||
110
cmake/EnableCompilerWarnings.cmake
Normal file
110
cmake/EnableCompilerWarnings.cmake
Normal file
@@ -0,0 +1,110 @@
|
||||
################################################################################
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2017 Advanced Micro Devices, Inc.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
################################################################################
|
||||
# - Enable warning all for gcc/clang or use /W4 for visual studio
|
||||
|
||||
## Strict warning level
|
||||
if (MSVC)
|
||||
# Use the highest warning level for visual studio.
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w")
|
||||
# set(CMAKE_CXX_WARNING_LEVEL 4)
|
||||
# if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
|
||||
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
# else ()
|
||||
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
|
||||
# endif ()
|
||||
|
||||
# set(CMAKE_C_WARNING_LEVEL 4)
|
||||
# if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
|
||||
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
|
||||
# else ()
|
||||
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
|
||||
# endif ()
|
||||
|
||||
else()
|
||||
foreach(COMPILER C CXX)
|
||||
set(CMAKE_COMPILER_WARNINGS)
|
||||
# use -Wall for gcc and clang
|
||||
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||
-Wall
|
||||
-Wextra
|
||||
-Wcomment
|
||||
-Wendif-labels
|
||||
-Wformat
|
||||
-Winit-self
|
||||
-Wreturn-type
|
||||
-Wsequence-point
|
||||
# Shadow is broken on gcc when using lambdas
|
||||
# -Wshadow
|
||||
-Wswitch
|
||||
-Wtrigraphs
|
||||
-Wundef
|
||||
-Wuninitialized
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
|
||||
-Wno-sign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
)
|
||||
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
|
||||
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||
-Weverything
|
||||
-Wno-c++98-compat
|
||||
-Wno-c++98-compat-pedantic
|
||||
-Wno-conversion
|
||||
-Wno-double-promotion
|
||||
-Wno-exit-time-destructors
|
||||
-Wno-extra-semi
|
||||
-Wno-float-conversion
|
||||
-Wno-gnu-anonymous-struct
|
||||
-Wno-gnu-zero-variadic-macro-arguments
|
||||
-Wno-missing-prototypes
|
||||
-Wno-nested-anon-types
|
||||
-Wno-padded
|
||||
-Wno-return-std-move-in-c++11
|
||||
-Wno-shorten-64-to-32
|
||||
-Wno-sign-conversion
|
||||
-Wno-unknown-warning-option
|
||||
-Wno-unused-command-line-argument
|
||||
-Wno-weak-vtables
|
||||
-Wno-covered-switch-default
|
||||
)
|
||||
else()
|
||||
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
|
||||
# cmake 3.5.2 does not support >=.
|
||||
if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1")
|
||||
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||
-Wno-ignored-attributes)
|
||||
endif()
|
||||
endif()
|
||||
list(APPEND CMAKE_COMPILER_WARNINGS
|
||||
-Wno-missing-field-initializers
|
||||
-Wno-deprecated-declarations
|
||||
)
|
||||
endif()
|
||||
add_definitions(${CMAKE_COMPILER_WARNINGS})
|
||||
endforeach()
|
||||
endif ()
|
||||
14
composable_kernel/include/gridwise_operation_wrapper.hpp
Normal file
14
composable_kernel/include/gridwise_operation_wrapper.hpp
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
|
||||
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
|
||||
|
||||
template <typename GridwiseOp, typename... Xs>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
run_gridwise_operation(Xs... xs)
|
||||
{
|
||||
GridwiseOp{}.Run(xs...);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,272 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Number of GEMMs = YTilda * XTilda
|
||||
// GemmM = C
|
||||
// GemmN = N * HTildaSlice * WTildaSlice
|
||||
// GemmK = K * YDotSlice * XDotSlice
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildaValue,
|
||||
index_t IXTildaValue,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildaValue>,
|
||||
Number<IXTildaValue>,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilda = Number<IYTildaValue>{};
|
||||
constexpr auto IXTilda = Number<IXTildaValue>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
|
||||
const auto IHTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
|
||||
const auto IWTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildaSliceEnd =
|
||||
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildaSliceEnd =
|
||||
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilda),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilda),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
#if 1
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// output tensor
|
||||
// this add padding check
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilda),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilda),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
#if 1
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilda, WTilda),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(C),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,275 @@
|
||||
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: out
|
||||
// B: wei
|
||||
// C: in
|
||||
// Number of GEMMs = YTilda * XTilda
|
||||
// GemmM = N * HTildaSlice * WTildaSlice
|
||||
// GemmN = C
|
||||
// GemmK = K * YDotSlice * XDotSlice
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t IYTildaValue,
|
||||
index_t IXTildaValue,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<IYTildaValue>,
|
||||
Number<IXTildaValue>,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
constexpr auto IYTilda = Number<IYTildaValue>{};
|
||||
constexpr auto IXTilda = Number<IXTildaValue>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
const auto YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
const auto XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
|
||||
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
|
||||
|
||||
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
|
||||
const auto IHTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
|
||||
const auto IWTildaSliceBegin = math::integer_divide_floor(
|
||||
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
|
||||
|
||||
const auto IHTildaSliceEnd =
|
||||
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
|
||||
const auto IWTildaSliceEnd =
|
||||
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
|
||||
|
||||
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
|
||||
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
|
||||
|
||||
// GemmK is different for each GEMM
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
|
||||
|
||||
const auto K1 = GemmK1;
|
||||
const auto K0 = K / K1;
|
||||
|
||||
// A: output tensor
|
||||
// this add padding check
|
||||
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ho_wo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, I0, I0),
|
||||
make_pad_transform(Wo, I0, I0),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
|
||||
out_n_hop_wop_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YDot, HTilda),
|
||||
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, WTilda),
|
||||
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_unmerge_transform(make_tuple(K0, K1))),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6>{}));
|
||||
|
||||
#if 1
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_y_x_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_embed_transform(make_tuple(YDot, YTilda),
|
||||
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
|
||||
make_embed_transform(make_tuple(XDot, XTilda),
|
||||
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
|
||||
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
|
||||
make_slice_transform(YDot, I0, YDotSlice),
|
||||
make_slice_transform(XDot, I0, XDotSlice),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<3>{},
|
||||
Sequence<2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0, 1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<>{},
|
||||
Sequence<>{},
|
||||
Sequence<4>{}));
|
||||
|
||||
#if 1
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#else
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
|
||||
make_pass_through_transform(C),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
#endif
|
||||
|
||||
// C: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(YTilda, HTilda),
|
||||
make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(XTilda, WTilda),
|
||||
make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_freeze_transform(IYTilda),
|
||||
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
|
||||
make_freeze_transform(IXTilda),
|
||||
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<>{},
|
||||
Sequence<1>{},
|
||||
Sequence<>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_htildaslice_wtildaslice_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,257 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_global_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
|
||||
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
|
||||
InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,176 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
|
||||
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
|
||||
InRightPadW == 0);
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,129 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,129 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM = K
|
||||
// GemmN = N * Ho * Wo
|
||||
// GemmK = C * Y * X
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = K;
|
||||
const auto GemmN = N * Ho * Wo;
|
||||
const auto GemmK = C * Y * X;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,132 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A: in
|
||||
// B: wei
|
||||
// C: out
|
||||
// GemmM = N * Ho * Wo
|
||||
// GemmN = K
|
||||
// GemmK = Y * X * C
|
||||
template <typename... In,
|
||||
typename... Wei,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t GemmK1Value>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
|
||||
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Number<GemmK1Value>)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto GemmK1 = Number<GemmK1Value>{};
|
||||
|
||||
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
|
||||
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
|
||||
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
|
||||
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
|
||||
|
||||
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
|
||||
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto GemmM = N * Ho * Wo;
|
||||
const auto GemmN = K;
|
||||
const auto GemmK = Y * X * C;
|
||||
const auto GemmK0 = GemmK / GemmK1;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,132 @@
|
||||
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// GemmM0 = 1
|
||||
// GemmM1 = K
|
||||
// GemmN0 = N0
|
||||
// GemmN1 = (N / N0) * Ho * Wo
|
||||
// GemmK0 = (C / C0) * Y * X
|
||||
// GemmK1 = C0
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
typename N0Type,
|
||||
typename C0Type>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
|
||||
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
|
||||
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const N0Type& N0,
|
||||
const C0Type& C0)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
|
||||
|
||||
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
const auto N1 = N / N0;
|
||||
const auto C1 = C / C0;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gk0_gm0_gm1_gk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_unmerge_transform(make_tuple(C0, C1)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
|
||||
make_pass_through_transform(N0),
|
||||
make_merge_transform(make_tuple(N1, Ho, Wo)),
|
||||
make_pass_through_transform(C0)),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_n_k_howo_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));
|
||||
|
||||
const auto out_n0_n1_1_k_howo_grid_desc =
|
||||
transform_tensor_descriptor(out_n_k_howo_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_unmerge_transform(make_tuple(I1, K)),
|
||||
make_pass_through_transform(Ho * Wo)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_1_k_howo_grid_desc,
|
||||
make_tuple(make_pass_through_transform(I1),
|
||||
make_pass_through_transform(K),
|
||||
make_pass_through_transform(N0),
|
||||
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
|
||||
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
return make_tuple(
|
||||
wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,33 @@
|
||||
#ifndef CK_CLUSTER_DESCRIPTOR_HPP
|
||||
#define CK_CLUSTER_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename Lengths,
|
||||
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
|
||||
__host__ __device__ constexpr auto make_cluster_descriptor(
|
||||
const Lengths& lengths,
|
||||
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
|
||||
{
|
||||
constexpr index_t ndim_low = Lengths::Size();
|
||||
|
||||
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
|
||||
|
||||
const auto low_lengths = generate_tuple(
|
||||
[&](auto idim_low) { return reordered_lengths[idim_low]; }, Number<ndim_low>{});
|
||||
|
||||
const auto transform = make_merge_transform(low_lengths);
|
||||
|
||||
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
|
||||
|
||||
constexpr auto up_dim_new_top_ids = Sequence<0>{};
|
||||
|
||||
return make_single_stage_tensor_adaptor(
|
||||
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,103 @@
|
||||
#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
|
||||
#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename LowLength>
|
||||
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
|
||||
{
|
||||
return PassThrough<LowLength>{low_length};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
|
||||
__host__ __device__ constexpr auto
|
||||
make_pad_transform(const LowLength& low_length,
|
||||
const LeftPad& left_pad,
|
||||
const RightPad& right_pad,
|
||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||
{
|
||||
return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
|
||||
__host__ __device__ constexpr auto make_left_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const LeftPadLength& left_pad,
|
||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||
{
|
||||
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
|
||||
__host__ __device__ constexpr auto make_right_pad_transform(
|
||||
const LowLength& low_length,
|
||||
const RightPadLength& right_pad,
|
||||
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
|
||||
{
|
||||
return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
|
||||
}
|
||||
|
||||
template <typename UpLengths,
|
||||
typename Coefficients,
|
||||
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
|
||||
const Coefficients& coefficients)
|
||||
{
|
||||
return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
|
||||
{
|
||||
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
|
||||
return Merge_v1_carry_check<LowLengths>{low_lengths};
|
||||
#else
|
||||
#if 1
|
||||
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||
#else
|
||||
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
|
||||
{
|
||||
return Merge_v2_magic_division<LowLengths>{low_lengths};
|
||||
}
|
||||
|
||||
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
|
||||
__host__ __device__ constexpr auto make_unmerge_transform(
|
||||
const UpLengths& up_lengths,
|
||||
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
|
||||
{
|
||||
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
|
||||
}
|
||||
|
||||
template <typename LowerIndex>
|
||||
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
|
||||
{
|
||||
return Freeze<LowerIndex>{low_idx};
|
||||
}
|
||||
|
||||
template <typename LowLength, typename SliceBegin, typename SliceEnd>
|
||||
__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
|
||||
const SliceBegin& slice_begin,
|
||||
const SliceEnd& slice_end)
|
||||
{
|
||||
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
|
||||
}
|
||||
|
||||
template <typename VectorSize, typename UpLength>
|
||||
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
|
||||
const UpLength& up_length)
|
||||
{
|
||||
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
464
composable_kernel/include/tensor_description/tensor_adaptor.hpp
Normal file
464
composable_kernel/include/tensor_description/tensor_adaptor.hpp
Normal file
@@ -0,0 +1,464 @@
|
||||
#ifndef CK_TENSOR_ADAPTOR_HPP
|
||||
#define CK_TENSOR_ADAPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
|
||||
// BottomDimensionHiddenIds : Sequence<...>
|
||||
// TopDimensionHiddenIds : Sequence<...>
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionHiddenIdss,
|
||||
typename UpperDimensionHiddenIdss,
|
||||
typename BottomDimensionHiddenIds,
|
||||
typename TopDimensionHiddenIds>
|
||||
struct TensorAdaptor
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
|
||||
|
||||
__host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
|
||||
{
|
||||
return LowerDimensionHiddenIdss{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
|
||||
{
|
||||
return UpperDimensionHiddenIdss{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
|
||||
{
|
||||
return TopDimensionHiddenIds{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
|
||||
{
|
||||
return BottomDimensionHiddenIds{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
|
||||
{
|
||||
const auto lengths = generate_tuple(
|
||||
[&](auto idim_top) {
|
||||
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);
|
||||
|
||||
constexpr index_t itran = tmp[Number<0>{}];
|
||||
constexpr index_t idim_up = tmp[Number<1>{}];
|
||||
constexpr bool found = tmp[Number<2>{}];
|
||||
|
||||
static_assert(found == true,
|
||||
"wrong! not found matching transformation and upper-dimension");
|
||||
|
||||
const auto length =
|
||||
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
|
||||
|
||||
return length;
|
||||
},
|
||||
Number<ndim_top_>{});
|
||||
|
||||
// TODO: make container_reduce support tuple of Number and index_t
|
||||
return container_reduce(lengths, math::multiplies{}, Number<1>{});
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
|
||||
{
|
||||
constexpr auto idim_top = Number<IDim>{};
|
||||
|
||||
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
|
||||
|
||||
index_t itran_found = 0;
|
||||
index_t idim_up_found = 0;
|
||||
bool found = false;
|
||||
|
||||
static_for<0, ntransform_, 1>{}([&](auto itran) {
|
||||
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
|
||||
|
||||
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
|
||||
if constexpr(up_dim_ids[idim_up] == idim_hidden)
|
||||
{
|
||||
itran_found = itran;
|
||||
idim_up_found = idim_up;
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return make_tuple(itran_found, idim_up_found, found);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
|
||||
{
|
||||
return BottomDimensionHiddenIds::Size();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfTopDimension()
|
||||
{
|
||||
return TopDimensionHiddenIds::Size();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
LowerDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
|
||||
UpperDimensionHiddenIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
|
||||
math::less<index_t>,
|
||||
math::equal<index_t>>::type;
|
||||
|
||||
return unique_sort_all_dim_ids::Size();
|
||||
}
|
||||
|
||||
constexpr static index_t ntransform_ = GetNumOfTransform();
|
||||
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
|
||||
constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
|
||||
constexpr static index_t ndim_top_ = GetNumOfTopDimension();
|
||||
|
||||
using HiddenIndex = MultiIndex<ndim_hidden_>;
|
||||
using BottomIndex = MultiIndex<ndim_bottom_>;
|
||||
using TopIndex = MultiIndex<ndim_top_>;
|
||||
|
||||
// may be index_t or Number<>
|
||||
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
|
||||
|
||||
public:
|
||||
__host__ __device__ constexpr TensorAdaptor() = default;
|
||||
|
||||
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
|
||||
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
|
||||
{
|
||||
static_assert(Transforms::Size() == ntransform_ &&
|
||||
LowerDimensionHiddenIdss::Size() == ntransform_ &&
|
||||
UpperDimensionHiddenIdss::Size() == ntransform_,
|
||||
"wrong! inconsistent # of transformations");
|
||||
|
||||
// TODO check dependency of dimensions is valid
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
|
||||
|
||||
template <typename TopIdx>
|
||||
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
|
||||
{
|
||||
static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = GetNumOfTransform();
|
||||
constexpr index_t ndim_hidden = GetNumOfHiddenDimension();
|
||||
|
||||
MultiIndex<ndim_hidden> idx_hidden;
|
||||
|
||||
// initialize uppest index
|
||||
set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top);
|
||||
|
||||
// calculate hidden index
|
||||
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
|
||||
auto itran = itran_p1 - Number<1>{};
|
||||
const auto& tran = GetTransforms().At(itran);
|
||||
constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran);
|
||||
constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran);
|
||||
|
||||
const auto idx_up = get_container_subset(idx_hidden, dims_up);
|
||||
|
||||
MultiIndex<dims_low.Size()> idx_low;
|
||||
|
||||
tran.CalculateLowerIndex(idx_low, idx_up);
|
||||
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
});
|
||||
|
||||
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
bool is_known = true;
|
||||
|
||||
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
|
||||
is_known &=
|
||||
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
|
||||
});
|
||||
|
||||
return is_known && is_known_at_compile_time<ElementSize>::value;
|
||||
}
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("TensorAdaptor, ");
|
||||
static_for<0, ntransform_, 1>{}([&](auto i) {
|
||||
printf("transforms: ");
|
||||
transforms_[i].Print();
|
||||
printf("LowerDimensionHiddenIds:");
|
||||
LowerDimensionHiddenIdss{}.At(i).Print();
|
||||
printf("UpperDimensionHiddenIds:");
|
||||
UpperDimensionHiddenIdss{}.At(i).Print();
|
||||
});
|
||||
|
||||
printf("BottomDimensionHiddenIds:");
|
||||
BottomDimensionHiddenIds::Print();
|
||||
printf("TopDimensionHiddenIds:");
|
||||
TopDimensionHiddenIds::Print();
|
||||
|
||||
printf("}");
|
||||
}
|
||||
|
||||
private:
|
||||
Transforms transforms_;
|
||||
ElementSize element_size_;
|
||||
};
|
||||
|
||||
template <typename TensorAdaptor0, typename TensorAdaptor1>
|
||||
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
|
||||
const TensorAdaptor1& adaptor1)
|
||||
{
|
||||
static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
|
||||
TensorAdaptor1::GetNumOfBottomDimension(),
|
||||
"wrong!");
|
||||
|
||||
// all_transforms = transform0 + transform1
|
||||
const auto all_transforms =
|
||||
container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms());
|
||||
|
||||
// shift
|
||||
constexpr index_t adaptor0_max_hidden_id = [&]() {
|
||||
index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();
|
||||
|
||||
static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();
|
||||
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
adaptor0_max_hidden_id_ =
|
||||
math::max(adaptor0_max_hidden_id_,
|
||||
TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
|
||||
});
|
||||
|
||||
constexpr index_t ndim_up =
|
||||
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();
|
||||
|
||||
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
|
||||
adaptor0_max_hidden_id_ =
|
||||
math::max(adaptor0_max_hidden_id_,
|
||||
TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
|
||||
});
|
||||
});
|
||||
|
||||
return adaptor0_max_hidden_id_;
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_min_hidden_id = [&]() {
|
||||
index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();
|
||||
|
||||
static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
|
||||
|
||||
// get the min of all lower dimenions, but not bottom dimension (because their id will
|
||||
// be matched with top id from adaptor0)
|
||||
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
|
||||
constexpr index_t low_dim_hidden_id =
|
||||
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
|
||||
|
||||
bool is_bottom_dim = false;
|
||||
static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
|
||||
if constexpr(low_dim_hidden_id ==
|
||||
TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
|
||||
{
|
||||
is_bottom_dim = true;
|
||||
}
|
||||
});
|
||||
|
||||
if(!is_bottom_dim)
|
||||
{
|
||||
adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
|
||||
}
|
||||
});
|
||||
|
||||
constexpr index_t ndim_up =
|
||||
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
|
||||
|
||||
// get the min of all upper dimensions
|
||||
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
|
||||
adaptor1_min_hidden_id_ =
|
||||
math::min(adaptor1_min_hidden_id_,
|
||||
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
|
||||
});
|
||||
});
|
||||
|
||||
return adaptor1_min_hidden_id_;
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_hidden_id_shift =
|
||||
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
|
||||
|
||||
constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
|
||||
|
||||
// all_low_dim_hidden_idss =
|
||||
// low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
|
||||
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
|
||||
// generate sequence of ids for a transform
|
||||
[&](auto itran) {
|
||||
constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();
|
||||
|
||||
constexpr auto low_dim_hidden_ids_1 =
|
||||
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
|
||||
|
||||
// sequence in, sequence out
|
||||
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id so every dim id is unique
|
||||
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
|
||||
low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
|
||||
});
|
||||
|
||||
// match hidden id
|
||||
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
|
||||
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
|
||||
// if this low dim is bottom dim, then do id matching
|
||||
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
|
||||
TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
|
||||
{
|
||||
low_dim_hidden_ids_1_mod_(idim_low_1) =
|
||||
TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return low_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
|
||||
Number<ndim_low_1>{});
|
||||
},
|
||||
Number<TensorAdaptor1::GetNumOfTransform()>{});
|
||||
|
||||
constexpr auto all_low_dim_hidden_idss =
|
||||
container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);
|
||||
|
||||
// all_up_dim_hidden_idss =
|
||||
// up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1)
|
||||
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
|
||||
// generate sequence of ids for a transform
|
||||
[&](auto itran) {
|
||||
constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();
|
||||
|
||||
constexpr auto up_dim_hidden_ids_1 =
|
||||
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
|
||||
|
||||
// sequence in, constexpr tuple out
|
||||
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
|
||||
{
|
||||
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
|
||||
|
||||
// shift hidden id
|
||||
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
|
||||
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
|
||||
});
|
||||
|
||||
return up_dim_hidden_ids_1_mod_;
|
||||
}
|
||||
();
|
||||
|
||||
// constexpr tuple to sequence
|
||||
return generate_sequence_v2(
|
||||
[&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
|
||||
Number<ndim_up_1>{});
|
||||
},
|
||||
Number<TensorAdaptor1::GetNumOfTransform()>{});
|
||||
|
||||
constexpr auto all_up_dim_hidden_idss =
|
||||
container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);
|
||||
|
||||
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
|
||||
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();
|
||||
|
||||
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
|
||||
constexpr auto top_dim_hidden_ids =
|
||||
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
|
||||
|
||||
// put everything together
|
||||
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
|
||||
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
|
||||
}
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
|
||||
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
|
||||
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
|
||||
LowerDimensionOldTopIdss,
|
||||
UpperDimensionNewTopIdss)
|
||||
{
|
||||
constexpr index_t ntransform = Transforms::Size();
|
||||
|
||||
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
|
||||
UpperDimensionNewTopIdss::Size() == ntransform,
|
||||
"wrong!");
|
||||
|
||||
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
|
||||
constexpr auto all_low_dim_old_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
|
||||
|
||||
constexpr auto all_up_dim_new_top_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
|
||||
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
|
||||
|
||||
// low_dim_hidden_idss
|
||||
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
|
||||
|
||||
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
|
||||
Number<ntransform>{});
|
||||
|
||||
// bottom_dim_hidden_ids
|
||||
constexpr auto bottom_dim_hidden_ids =
|
||||
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
|
||||
|
||||
// top_dim_hidden_ids
|
||||
constexpr auto top_dim_hidden_ids =
|
||||
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
|
||||
|
||||
return TensorAdaptor<remove_cv_t<Transforms>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
|
||||
}
|
||||
|
||||
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||
{
|
||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,597 @@
|
||||
#ifndef CK_TENSOR_DESCRIPTOR_HPP
|
||||
#define CK_TENSOR_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t NDimHidden, typename VisibleDimensionIds>
|
||||
struct TensorCoordinate;
|
||||
|
||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||
struct TensorCoordinateStep;
|
||||
|
||||
// Transforms: Tuple<transforms...>
|
||||
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
|
||||
// UpperDimensionIdss : Tuple<Sequence<...>, ...>
|
||||
// VisibleDimensionIds> : Sequence<...>
|
||||
template <typename Transforms,
|
||||
typename LowerDimensionIdss,
|
||||
typename UpperDimensionIdss,
|
||||
typename VisibleDimensionIds,
|
||||
typename ElementSpaceSize>
|
||||
struct TensorDescriptor
|
||||
{
|
||||
// TODO make these private
|
||||
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfVisibleDimension()
|
||||
{
|
||||
return VisibleDimensionIds::Size();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
|
||||
{
|
||||
constexpr auto all_low_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
|
||||
|
||||
constexpr auto all_up_dim_ids = unpack(
|
||||
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
|
||||
|
||||
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
|
||||
|
||||
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
|
||||
math::less<index_t>,
|
||||
math::equal<index_t>>::type;
|
||||
|
||||
return unique_sort_all_dim_ids::Size();
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
|
||||
{
|
||||
const auto lengths = generate_tuple(
|
||||
[&](auto idim_visible) {
|
||||
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);
|
||||
|
||||
constexpr index_t itran = tmp[Number<0>{}];
|
||||
constexpr index_t idim_up = tmp[Number<1>{}];
|
||||
constexpr bool found = tmp[Number<2>{}];
|
||||
|
||||
static_assert(found == true,
|
||||
"wrong! not found matching transformation and upper-dimension");
|
||||
|
||||
const auto length =
|
||||
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
|
||||
|
||||
return length;
|
||||
},
|
||||
Number<ndim_visible_>{});
|
||||
|
||||
// TODO: make container_reduce support tuple of Number and index_t
|
||||
return container_reduce(lengths, math::multiplies{}, Number<1>{});
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
|
||||
{
|
||||
constexpr auto idim_visible = Number<IDim>{};
|
||||
|
||||
constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);
|
||||
|
||||
index_t itran_found = 0;
|
||||
index_t idim_up_found = 0;
|
||||
bool found = false;
|
||||
|
||||
static_for<0, ntransform_, 1>{}([&](auto itran) {
|
||||
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];
|
||||
|
||||
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
|
||||
if constexpr(up_dim_ids[idim_up] == idim_hidden)
|
||||
{
|
||||
itran_found = itran;
|
||||
idim_up_found = idim_up;
|
||||
found = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return make_tuple(itran_found, idim_up_found, found);
|
||||
}
|
||||
|
||||
constexpr static index_t ntransform_ = GetNumOfTransform();
|
||||
constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension();
|
||||
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
|
||||
|
||||
using VisibleIndex = MultiIndex<ndim_visible_>;
|
||||
using HiddenIndex = MultiIndex<ndim_hidden_>;
|
||||
using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
|
||||
|
||||
// may be index_t or Number<>
|
||||
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
|
||||
|
||||
public:
|
||||
__host__ __device__ constexpr TensorDescriptor() = default;
|
||||
|
||||
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
|
||||
ElementSpaceSize element_space_size)
|
||||
: transforms_{transforms},
|
||||
element_size_{InitializeElementSize(transforms)},
|
||||
element_space_size_{element_space_size}
|
||||
|
||||
{
|
||||
static_assert(Transforms::Size() == ntransform_ &&
|
||||
LowerDimensionIdss::Size() == ntransform_ &&
|
||||
UpperDimensionIdss::Size() == ntransform_,
|
||||
"wrong! inconsistent # of transformations");
|
||||
|
||||
// TODO check dependency of dimensions is valid
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetNumOfDimension()
|
||||
{
|
||||
return GetNumOfVisibleDimension();
|
||||
}
|
||||
|
||||
template <index_t IDim>
|
||||
__host__ __device__ constexpr auto GetLength(Number<IDim>) const
|
||||
{
|
||||
static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");
|
||||
|
||||
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});
|
||||
|
||||
constexpr index_t itran = tmp[Number<0>{}];
|
||||
constexpr index_t idim_up = tmp[Number<1>{}];
|
||||
constexpr bool found = tmp[Number<2>{}];
|
||||
|
||||
static_assert(found == true,
|
||||
"wrong! not found matching transformation and upper-dimension");
|
||||
|
||||
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
|
||||
|
||||
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
|
||||
|
||||
template <typename Idx>
|
||||
__host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
|
||||
{
|
||||
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
|
||||
|
||||
return make_tensor_coordinate(*this, idx).GetOffset();
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
__host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetLowerDimensionIdss()
|
||||
{
|
||||
return LowerDimensionIdss{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetUpperDimensionIdss()
|
||||
{
|
||||
return UpperDimensionIdss{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetVisibleDimensionIds()
|
||||
{
|
||||
return VisibleDimensionIds{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
bool is_known = true;
|
||||
|
||||
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
|
||||
is_known &=
|
||||
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
|
||||
});
|
||||
|
||||
return is_known && is_known_at_compile_time<ElementSize>::value &&
|
||||
is_known_at_compile_time<ElementSpaceSize>::value;
|
||||
}
|
||||
|
||||
__host__ __device__ void Print() const
|
||||
{
|
||||
printf("{");
|
||||
printf("TensorDescriptor, ");
|
||||
static_for<0, ntransform_, 1>{}([&](auto i) {
|
||||
printf("transforms: ");
|
||||
transforms_[i].Print();
|
||||
printf("LowerDimensionIds:");
|
||||
LowerDimensionIdss{}.At(i).Print();
|
||||
printf("UpperDimensionIds:");
|
||||
UpperDimensionIdss{}.At(i).Print();
|
||||
});
|
||||
printf("}");
|
||||
|
||||
VisibleDimensionIds::Print();
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
Transforms transforms_;
|
||||
ElementSize element_size_;
|
||||
ElementSpaceSize element_space_size_;
|
||||
};
|
||||
|
||||
template <index_t NDimHidden, typename VisibleDimensionIds>
|
||||
struct TensorCoordinate
|
||||
{
|
||||
// TODO make these private
|
||||
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
|
||||
|
||||
using HiddenIndex = MultiIndex<NDimHidden>;
|
||||
using VisibleIndex = MultiIndex<ndim_visible_>;
|
||||
|
||||
public:
|
||||
__host__ __device__ constexpr TensorCoordinate() = default;
|
||||
|
||||
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
|
||||
: idx_hidden_{idx_hidden}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); }
|
||||
|
||||
__host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; }
|
||||
|
||||
// TODO make these private
|
||||
__host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }
|
||||
|
||||
__host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; }
|
||||
|
||||
__host__ __device__ constexpr auto GetVisibleIndex() const
|
||||
{
|
||||
return get_container_subset(idx_hidden_, VisibleDimensionIds{});
|
||||
}
|
||||
|
||||
// TODO make these private
|
||||
HiddenIndex idx_hidden_;
|
||||
};
|
||||
|
||||
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
|
||||
struct TensorCoordinateStep
|
||||
{
|
||||
// TODO make these private
|
||||
using VisibleIndex = MultiIndex<NDimVisible>;
|
||||
|
||||
public:
|
||||
__host__ __device__ constexpr TensorCoordinateStep() = default;
|
||||
|
||||
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
|
||||
const MultiIndex<NTransform>& do_transforms)
|
||||
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); }
|
||||
|
||||
// TODO make these private
|
||||
__host__ __device__ constexpr const auto& GetVisibleIndexDiff() const
|
||||
{
|
||||
return idx_diff_visible_;
|
||||
}
|
||||
|
||||
VisibleIndex idx_diff_visible_;
|
||||
MultiIndex<NTransform> do_transforms_;
|
||||
|
||||
// HACK: control UpdateLowerIndex()
|
||||
static constexpr UpdateLowerIndexHack update_lower_index_hack_;
|
||||
};
|
||||
|
||||
// TODO: How to fix this? It uses an struct instead of lambda because lambda
|
||||
// doesn't have constructor, and to put it outside the scope where it is used
|
||||
// (transform_tensor_descriptor) because template cannot be defined inside a function
|
||||
// template
|
||||
template <typename NewTransforms>
|
||||
struct lambda_get_up_dim_num
|
||||
{
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr auto operator()(I) const
|
||||
{
|
||||
using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
|
||||
return Number<Tran::GetNumOfUpperDimension()>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OldTensorDescriptor,
|
||||
typename NewTransforms,
|
||||
typename NewLowerDimensionOldVisibleIdss,
|
||||
typename NewUpperDimensionNewVisibleIdss>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
|
||||
const NewTransforms& new_transforms,
|
||||
NewLowerDimensionOldVisibleIdss,
|
||||
NewUpperDimensionNewVisibleIdss)
|
||||
{
|
||||
// sanity check
|
||||
{
|
||||
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
|
||||
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
|
||||
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// lower dimension's hidden idss
|
||||
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
|
||||
// sequences)
|
||||
constexpr auto low_dim_hidden_idss = transform_tuples(
|
||||
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
|
||||
[](auto low_dim_visible_ids) constexpr {
|
||||
return transform_sequences(
|
||||
// convert lower dimension visible id to hidden id
|
||||
[](auto low_dim_visible_id) constexpr {
|
||||
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
|
||||
},
|
||||
low_dim_visible_ids);
|
||||
},
|
||||
NewLowerDimensionOldVisibleIdss{});
|
||||
|
||||
constexpr index_t num_new_transform = NewTransforms::Size();
|
||||
|
||||
// upper dimension's hidden idss
|
||||
constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension();
|
||||
|
||||
constexpr auto up_dim_numbers =
|
||||
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});
|
||||
|
||||
constexpr auto up_dim_numbers_scan = merge_sequences(
|
||||
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
|
||||
|
||||
constexpr auto up_dim_hidden_idss = generate_tuple(
|
||||
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
|
||||
return
|
||||
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
|
||||
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
|
||||
1>::type{};
|
||||
},
|
||||
Number<num_new_transform>{});
|
||||
|
||||
// new visible dimension's hidden ids
|
||||
constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
|
||||
|
||||
constexpr auto new_visible_dim_unordered2ordered = unpack(
|
||||
[](auto... xs) constexpr { return merge_sequences(xs...); },
|
||||
NewUpperDimensionNewVisibleIdss{});
|
||||
|
||||
constexpr auto new_visible_dim_hidden_ids =
|
||||
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
|
||||
|
||||
// put everything together
|
||||
const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);
|
||||
|
||||
constexpr auto all_low_dim_hidden_idss =
|
||||
container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
|
||||
|
||||
constexpr auto all_up_dim_hidden_idss =
|
||||
container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
|
||||
|
||||
const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
|
||||
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>>{all_transforms,
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename VisibleIndex>
|
||||
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||
const VisibleIndex& idx_visible)
|
||||
{
|
||||
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
||||
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
|
||||
|
||||
MultiIndex<ndim_hidden> idx_hidden;
|
||||
|
||||
// initialize visible index
|
||||
set_container_subset(idx_hidden, visible_dim_ids, idx_visible);
|
||||
|
||||
// calculate hidden index
|
||||
static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
|
||||
auto itran = itran_p1 - Number<1>{};
|
||||
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
||||
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
||||
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
|
||||
|
||||
const auto idx_up = get_container_subset(idx_hidden, dims_up);
|
||||
|
||||
MultiIndex<dims_low.Size()> idx_low;
|
||||
|
||||
tran.CalculateLowerIndex(idx_low, idx_up);
|
||||
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
});
|
||||
|
||||
return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
|
||||
}
|
||||
|
||||
// UpdateLowerIndexHack: Sequence<...>
|
||||
// HACK: control UpdateLowerIndex
|
||||
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
|
||||
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||
const VisibleIndex& idx_diff_visible,
|
||||
UpdateLowerIndexHack)
|
||||
{
|
||||
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
|
||||
"wrong! # of dimension inconsistent");
|
||||
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
||||
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
|
||||
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
|
||||
|
||||
static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");
|
||||
|
||||
// use index_t for boolean type
|
||||
auto do_transforms = make_zero_multi_index<ntransform>();
|
||||
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// decide do_transform by checkout non-zero index diff components
|
||||
MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
|
||||
|
||||
static_for<0, ndim_visible, 1>{}(
|
||||
[&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
|
||||
|
||||
set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
|
||||
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
||||
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
|
||||
|
||||
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
|
||||
|
||||
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
|
||||
|
||||
// if any of upper index diff components is non-zero, then
|
||||
// 1) Need to do this transform
|
||||
// 2) all components of lower index diff will assume to be non-zero and need to be
|
||||
// computed
|
||||
const bool idx_diff_up_has_non_zero = container_reduce(
|
||||
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
|
||||
|
||||
do_transforms(itran) = idx_diff_up_has_non_zero;
|
||||
|
||||
static_for<0, dims_low.Size(), 1>{}(
|
||||
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
|
||||
|
||||
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
|
||||
});
|
||||
|
||||
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
|
||||
do_transforms};
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename VisibleIndex>
|
||||
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
|
||||
const VisibleIndex& idx_diff_visible)
|
||||
{
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
|
||||
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
|
||||
TensorCoord& coord,
|
||||
const TensorCoordStep& coord_step)
|
||||
{
|
||||
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
|
||||
// this is what needs to be calculated
|
||||
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
|
||||
|
||||
// initialize visible index diff
|
||||
set_container_subset(
|
||||
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
|
||||
|
||||
// this is what needs to be updated
|
||||
auto& idx_hidden = coord.GetHiddenIndex();
|
||||
|
||||
// update visible index
|
||||
auto idx_hidden_pick_visible =
|
||||
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
|
||||
|
||||
idx_hidden_pick_visible += coord_step.GetIndexDiff();
|
||||
|
||||
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
|
||||
|
||||
// update rest of hidden index
|
||||
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
|
||||
if(coord_step.do_transforms_[itran])
|
||||
{
|
||||
const auto& tran = tensor_desc.GetTransforms().At(itran);
|
||||
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
|
||||
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
|
||||
|
||||
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
|
||||
auto idx_low = get_container_subset(idx_hidden, dims_low);
|
||||
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
|
||||
|
||||
MultiIndex<dims_low.Size()> idx_diff_low;
|
||||
|
||||
// HACK: control UpdateLowerIndex for Merge using hack
|
||||
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
|
||||
|
||||
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
|
||||
|
||||
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
|
||||
set_container_subset(idx_hidden, dims_low, idx_low);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord>
|
||||
__host__ __device__ constexpr bool
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc,
|
||||
const TensorCoord& coord)
|
||||
{
|
||||
bool valid = true;
|
||||
|
||||
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
|
||||
|
||||
const auto& idx_hidden = coord.GetHiddenIndex();
|
||||
|
||||
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
|
||||
const auto tran = tensor_desc.GetTransforms().At(itran);
|
||||
|
||||
// check validity, only if current transformation does not always has a valid mapping
|
||||
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
|
||||
{
|
||||
const auto idx_up =
|
||||
get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
|
||||
|
||||
// Comment: using valid = valid && .. will result in weird control flow in ISA
|
||||
valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
|
||||
}
|
||||
});
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
template <typename TensorDesc, typename TensorCoord>
|
||||
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
|
||||
const TensorCoord& coord)
|
||||
{
|
||||
// check visible index
|
||||
const auto& idx_visible = coord.GetVisibleIndex();
|
||||
|
||||
bool is_visible_index_valid = true;
|
||||
|
||||
static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
|
||||
[&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
|
||||
is_visible_index_valid =
|
||||
is_visible_index_valid &&
|
||||
(idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
|
||||
});
|
||||
|
||||
// check other hidden index
|
||||
return is_visible_index_valid &&
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord);
|
||||
}
|
||||
|
||||
template <typename TensorDesc>
|
||||
using TensorCoordinate_t = decltype(make_tensor_coordinate(
|
||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||
|
||||
template <typename TensorDesc>
|
||||
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
|
||||
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,149 @@
|
||||
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/*
|
||||
* These functions create tensor descriptor at runtime. If they are not constexpr, you will
|
||||
* likely see usage of scratch memory during construction of these tensor descriptors. So
|
||||
* it's better to call these functions on host and then pass the constructed tensor descritpors
|
||||
* to GPU. If the tensor descritpors being constructed are constexpr, then you can call these
|
||||
* functions on GPU without worrying about scratch memory usage.
|
||||
*/
|
||||
|
||||
#if CK_WORKAROUND_SWDEV_275126
|
||||
template <typename Lengths, typename Strides, index_t I, typename AccOld>
|
||||
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
|
||||
const Strides& strides,
|
||||
Number<I> i,
|
||||
AccOld acc_old)
|
||||
{
|
||||
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
|
||||
|
||||
if constexpr(i.value < Lengths::Size() - 1)
|
||||
{
|
||||
return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return acc_new;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
|
||||
const Tuple<Strides...>& strides)
|
||||
{
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss =
|
||||
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_275126
|
||||
// rocm-4.1 compiler would crash for recursive labmda
|
||||
// recursive function for reduction
|
||||
auto f = [&](auto fs, auto i, auto acc_old) {
|
||||
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
|
||||
|
||||
if constexpr(i.value < N - 1)
|
||||
{
|
||||
return fs(fs, i + Number<1>{}, acc_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return acc_new;
|
||||
}
|
||||
};
|
||||
|
||||
const auto element_space_size = f(f, Number<0>{}, Number<1>{});
|
||||
#else
|
||||
const auto element_space_size =
|
||||
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
|
||||
#endif
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>>{transforms,
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
// Lengths... can be:
|
||||
// 1) index_t, which is known at run-time
|
||||
// 2) Number<>, which is known at compile-time
|
||||
template <typename... Lengths>
|
||||
__host__ __device__ constexpr auto
|
||||
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
|
||||
{
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto transforms = make_tuple(make_unmerge_transform(lengths));
|
||||
|
||||
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
|
||||
|
||||
constexpr auto up_dim_hidden_idss =
|
||||
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
|
||||
|
||||
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
|
||||
remove_cv_t<decltype(low_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(up_dim_hidden_idss)>,
|
||||
remove_cv_t<decltype(visible_dim_hidden_ids)>,
|
||||
remove_cv_t<decltype(element_space_size)>>{transforms,
|
||||
element_space_size};
|
||||
}
|
||||
|
||||
template <typename... Lengths, typename Align>
|
||||
__host__ __device__ constexpr auto
|
||||
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
|
||||
{
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t N = sizeof...(Lengths);
|
||||
|
||||
const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number<N - 1>{}], align);
|
||||
|
||||
auto strides = generate_tuple(
|
||||
[&](auto i) {
|
||||
if constexpr(i.value == N - 1)
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
else if constexpr(i.value == N - 2)
|
||||
{
|
||||
return Number<stride_n_minus_2>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return container_reduce(lengths,
|
||||
math::multiplies{},
|
||||
Number<stride_n_minus_2>{},
|
||||
i + I1,
|
||||
Number<N - 1>{},
|
||||
I1);
|
||||
}
|
||||
},
|
||||
Number<N>{});
|
||||
|
||||
return make_naive_tensor_descriptor(lengths, strides);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,394 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
|
||||
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_contraction_dlops.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
|
||||
// A and B are visable to the whole block, C is distributed among each thread
|
||||
// Assume:
|
||||
// 1. A:
|
||||
// 1. AKMBlockDesc is known at compile-time
|
||||
// 2. ABlockBuffer is DynamicBuffer
|
||||
// 2. B:
|
||||
// 1. BKNBlockDesc is known at compile-time
|
||||
// 2. BBlockBuffer is DynamicBuffer
|
||||
// 3. C:
|
||||
// 1. CM0M1N0N1ThreadDesc is known at compile-time
|
||||
// 2. CThreadBuffer is StaticBuffer
|
||||
// Also assume:
|
||||
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
|
||||
template <
|
||||
index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AKMBlockDesc,
|
||||
typename BKNBlockDesc,
|
||||
index_t M1PerThreadM11,
|
||||
index_t N1PerThreadN11,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM100,
|
||||
index_t M1N1ThreadClusterN100,
|
||||
index_t M1N1ThreadClusterM101,
|
||||
index_t M1N1ThreadClusterN101,
|
||||
index_t AThreadCopyScalarPerVector_M11,
|
||||
index_t BThreadCopyScalarPerVector_N11,
|
||||
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
using CIndex = MultiIndex<4>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
|
||||
static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
|
||||
|
||||
static constexpr index_t M100 = M1N1ThreadClusterM100;
|
||||
static constexpr index_t N100 = M1N1ThreadClusterN100;
|
||||
|
||||
static constexpr index_t M101 = M1N1ThreadClusterM101;
|
||||
static constexpr index_t N101 = M1N1ThreadClusterN101;
|
||||
|
||||
static constexpr index_t M11 = M1PerThreadM11;
|
||||
static constexpr index_t N11 = N1PerThreadN11;
|
||||
|
||||
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
|
||||
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
|
||||
|
||||
static constexpr index_t M0 = M / M1;
|
||||
static constexpr index_t N0 = N / N1;
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
|
||||
{
|
||||
const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
|
||||
AKMBlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<K>{}),
|
||||
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
return a_k_m0_m1_block_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
|
||||
{
|
||||
const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
|
||||
BKNBlockDesc{},
|
||||
make_tuple(make_pass_through_transform(Number<K>{}),
|
||||
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
return b_k_n0_n1_block_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
|
||||
{
|
||||
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
// lower: [M, N]
|
||||
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
|
||||
|
||||
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
|
||||
{
|
||||
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
// lower: [M0, M1, N0, N1]
|
||||
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_pass_through_transform(Number<M0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
|
||||
make_pass_through_transform(Number<N0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
|
||||
|
||||
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
|
||||
{
|
||||
return Sequence<M0, M11, N0, N11>{};
|
||||
}
|
||||
|
||||
static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
|
||||
static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
|
||||
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
|
||||
get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
|
||||
b_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
|
||||
{
|
||||
static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BlockSize == M101 * M100 * N101 * N100,
|
||||
"wrong! blocksize and cluster size not consistent");
|
||||
|
||||
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
|
||||
|
||||
static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(M0 == 2 && N0 == 2, "wrong");
|
||||
}
|
||||
|
||||
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
|
||||
{
|
||||
// lower: [M0, M1, N0, N1]
|
||||
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
|
||||
|
||||
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
|
||||
// upper: [Tid, M0, M11, N0, N11]
|
||||
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
|
||||
make_pass_through_transform(M0),
|
||||
make_pass_through_transform(M11),
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(N11)),
|
||||
make_tuple(
|
||||
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
|
||||
|
||||
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }
|
||||
|
||||
template <typename CM0M1N0N1ThreadDesc,
|
||||
typename ABlockBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
|
||||
const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
|
||||
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
|
||||
"wrong");
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
|
||||
a_k_m0_m1_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
|
||||
b_k_n0_n1_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_gemm =
|
||||
ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_k_m0_m1_thread_desc_),
|
||||
decltype(b_k_n0_n1_thread_desc_),
|
||||
CM0M1N0N1ThreadDesc,
|
||||
Sequence<KPerThread>,
|
||||
Sequence<1, M1PerThreadM11>,
|
||||
Sequence<1, N1PerThreadN11>>{};
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_k_m0_m1_thread_desc_,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_k_n0_n1_thread_desc_,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_block_buf,
|
||||
b_k_n0_n1_thread_desc_,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
|
||||
make_tuple(I0, I1, I0),
|
||||
a_block_buf,
|
||||
a_k_m0_m1_thread_desc_,
|
||||
make_tuple(I0, I1, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
|
||||
// loop over rest of k
|
||||
static_for<KPerThread, K, KPerThread>{}([&](auto k) {
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
|
||||
make_tuple(k, I0, I0),
|
||||
a_block_buf,
|
||||
a_k_m0_m1_thread_desc_,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
|
||||
make_tuple(k, I0, I0),
|
||||
b_block_buf,
|
||||
b_k_n0_n1_thread_desc_,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
|
||||
make_tuple(k, I1, I0),
|
||||
b_block_buf,
|
||||
b_k_n0_n1_thread_desc_,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
|
||||
make_tuple(k, I1, I0),
|
||||
a_block_buf,
|
||||
a_k_m0_m1_thread_desc_,
|
||||
make_tuple(I0, I1, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
});
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
}
|
||||
|
||||
private:
|
||||
// A[K, M0, M1]
|
||||
static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
|
||||
|
||||
// B[K, N0, N1]
|
||||
static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
decltype(a_k_m0_m1_block_desc_),
|
||||
decltype(a_k_m0_m1_thread_desc_),
|
||||
Sequence<KPerThread, 1, M1PerThreadM11>,
|
||||
Sequence<0, 1, 2>,
|
||||
2,
|
||||
AThreadCopyScalarPerVector_M11,
|
||||
1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
|
||||
FloatB,
|
||||
decltype(b_k_n0_n1_block_desc_),
|
||||
decltype(b_k_n0_n1_thread_desc_),
|
||||
Sequence<KPerThread, 1, N1PerThreadN11>,
|
||||
Sequence<0, 1, 2>,
|
||||
2,
|
||||
BThreadCopyScalarPerVector_N11,
|
||||
1>;
|
||||
|
||||
CIndex c_thread_origin_data_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,410 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
|
||||
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_adaptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_contraction_dlops.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
|
||||
// A and B are visable to the whole block, C is distributed among each thread
|
||||
// Assume:
|
||||
// 1. A:
|
||||
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
|
||||
// 2. ABlockBuffer is DynamicBuffer
|
||||
// 2. B:
|
||||
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
|
||||
// 2. BBlockBuffer is DynamicBuffer
|
||||
// 3. C:
|
||||
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
|
||||
// 2. CThreadBuffer is StaticBuffer
|
||||
// Also assume:
|
||||
// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
|
||||
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ABlockDesc_BK0_BM_BK1,
|
||||
typename BBlockDesc_BK0_BN_BK1,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
|
||||
// BM10BN10ThreadClusterBM101, ...>
|
||||
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
|
||||
// BM10BN10ThreadClusterBN101, ...>
|
||||
index_t AThreadCopyScalarPerVector_BM11,
|
||||
index_t BThreadCopyScalarPerVector_BN11,
|
||||
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
using CIndex = MultiIndex<4>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
|
||||
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
|
||||
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
|
||||
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
|
||||
|
||||
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
|
||||
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
|
||||
|
||||
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
|
||||
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
|
||||
|
||||
static constexpr index_t BM11 = BM1PerThreadBM11;
|
||||
static constexpr index_t BN11 = BN1PerThreadBN11;
|
||||
|
||||
static constexpr index_t BM1 = BM100 * BM101 * BM11;
|
||||
static constexpr index_t BN1 = BN100 * BN101 * BN11;
|
||||
|
||||
static constexpr index_t BM0 = BM / BM1;
|
||||
static constexpr index_t BN0 = BN / BN1;
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
|
||||
{
|
||||
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
|
||||
a_block_desc_bk0_bm_bk1,
|
||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
|
||||
make_pass_through_transform(Number<BK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return a_block_bk0_bm0_bm1_bk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
|
||||
{
|
||||
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
|
||||
b_block_desc_bk0_bn_bk1,
|
||||
make_tuple(make_pass_through_transform(Number<BK0>{}),
|
||||
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
|
||||
make_pass_through_transform(Number<BK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return b_block_desc_bk0_bn0_bn1_bk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
|
||||
{
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// lower: [BM, BN]
|
||||
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
|
||||
|
||||
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
|
||||
{
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// lower: [BM0, BM1, BN0, BN1]
|
||||
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
|
||||
make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_pass_through_transform(Number<BM0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
|
||||
make_pass_through_transform(Number<BN0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
|
||||
|
||||
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
|
||||
{
|
||||
return Sequence<BM0, BM11, BN0, BN11>{};
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
|
||||
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
|
||||
|
||||
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
|
||||
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
|
||||
b_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
|
||||
{
|
||||
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
|
||||
"wrong! blocksize and cluster size not consistent");
|
||||
|
||||
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
|
||||
|
||||
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
|
||||
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
// TODO remove this restriction
|
||||
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
|
||||
BM10BN10ThreadClusterBN10Xs::Size() == 2,
|
||||
"wrong!");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(BM0 == 2 && BN0 == 2, "wrong");
|
||||
}
|
||||
|
||||
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
|
||||
{
|
||||
// lower: [BM0, BM1, BN0, BN1]
|
||||
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
constexpr auto adaptor0 =
|
||||
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
|
||||
|
||||
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
|
||||
// upper: [Tid, BM0, BM11, BN0, BN11]
|
||||
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
|
||||
make_pass_through_transform(BM0),
|
||||
make_pass_through_transform(BM11),
|
||||
make_pass_through_transform(BN0),
|
||||
make_pass_through_transform(BN11)),
|
||||
make_tuple(
|
||||
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
|
||||
|
||||
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
|
||||
}
|
||||
|
||||
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
|
||||
typename ABlockBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename CThreadBuffer>
|
||||
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
|
||||
const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(BM0 == 2 && BN0 == 2 &&
|
||||
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
|
||||
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
|
||||
"wrong");
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
|
||||
|
||||
constexpr auto threadwise_contraction =
|
||||
ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
|
||||
FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
|
||||
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
|
||||
CThreadDesc_BM0_BM11_BN0_BN11,
|
||||
Sequence<BK0PerThread, BK1>,
|
||||
Sequence<1, BM1PerThreadBM11>,
|
||||
Sequence<1, BN1PerThreadBN11>>{};
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
|
||||
// loop over rest of bk0
|
||||
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(bk0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(bk0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(bk0, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_bk0_bn0_bn1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(bk0, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_bk0_bm0_bm1_bk1_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0));
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I0, I0, I1, I0));
|
||||
});
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I0, I0));
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_contraction.Run(a_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
c_thread_buf,
|
||||
make_tuple(I1, I0, I1, I0));
|
||||
}
|
||||
|
||||
private:
|
||||
// A[BK0, BM0, BM1, BK1]
|
||||
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
|
||||
|
||||
// B[BK0, BN0, BN1, BK1]
|
||||
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
|
||||
FloatA,
|
||||
FloatA,
|
||||
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
|
||||
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
|
||||
Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
|
||||
FloatB,
|
||||
FloatB,
|
||||
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
|
||||
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
|
||||
Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
|
||||
Sequence<0, 1, 2, 3>, // DimAccessOrder
|
||||
Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
|
||||
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
|
||||
|
||||
CIndex c_thread_origin_data_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,185 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
|
||||
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "threadwise_gemm_dlops_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename BlockMatrixA,
|
||||
typename BlockMatrixB,
|
||||
typename ThreadMatrixC,
|
||||
index_t KPerThread,
|
||||
index_t HPerThread,
|
||||
index_t WPerThread,
|
||||
index_t EPerThreadLoop,
|
||||
index_t ThreadGemmADataPerRead_K,
|
||||
index_t ThreadGemmBDataPerRead_W>
|
||||
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
|
||||
{
|
||||
struct MatrixIndex
|
||||
{
|
||||
index_t k;
|
||||
index_t h;
|
||||
index_t w;
|
||||
};
|
||||
|
||||
// HACK: fix this @Jing Zhang
|
||||
static constexpr index_t KPerThreadSubC = 4;
|
||||
|
||||
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
|
||||
|
||||
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||
|
||||
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
|
||||
FloatA,
|
||||
BlockMatrixA,
|
||||
decltype(a_thread_mtx_),
|
||||
Sequence<EPerThreadLoop, KPerThreadSubC>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
ThreadGemmADataPerRead_K,
|
||||
1>;
|
||||
|
||||
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
|
||||
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
|
||||
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
|
||||
{
|
||||
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
|
||||
BlockMatrixB::IsKnownAtCompileTime() &&
|
||||
ThreadMatrixC::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent\n");
|
||||
|
||||
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
|
||||
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
|
||||
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
|
||||
|
||||
static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
|
||||
"wrong! Cannot evenly divide work among\n");
|
||||
|
||||
constexpr auto KThreadCluster = K / KPerThread;
|
||||
constexpr auto HThreadCluster = H / HPerThread;
|
||||
constexpr auto WThreadCluster = W / WPerThread;
|
||||
|
||||
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
|
||||
"wrong! wrong blocksize\n");
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetThreadMatrixCLengths()
|
||||
{
|
||||
return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
|
||||
}
|
||||
|
||||
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
|
||||
{
|
||||
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
|
||||
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
|
||||
|
||||
constexpr auto num_w_threads = W / WPerThread;
|
||||
constexpr auto num_h_threads = H / HPerThread;
|
||||
constexpr auto num_hw_threads = num_w_threads * num_h_threads;
|
||||
|
||||
index_t k_thread_id = thread_id / num_hw_threads;
|
||||
index_t hw_thread_id = thread_id % num_hw_threads;
|
||||
|
||||
index_t h_thread_id = hw_thread_id / num_w_threads;
|
||||
index_t w_thread_id = hw_thread_id % num_w_threads;
|
||||
|
||||
return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BThreadBuffer& b_thread_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
|
||||
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
|
||||
|
||||
// HACK: fix this @Jing Zhang
|
||||
constexpr auto HoPerThreadSubC = 2;
|
||||
constexpr auto WoPerThreadSubC = 2;
|
||||
|
||||
static_assert(KPerThread % KPerThreadSubC == 0, "");
|
||||
static_assert(HPerThread % HoPerThreadSubC == 0, "");
|
||||
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
||||
|
||||
// thread A buffer for GEMM
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
|
||||
a_thread_buf;
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
|
||||
FloatB,
|
||||
FloatC,
|
||||
decltype(a_thread_mtx_),
|
||||
decltype(b_thread_mtx_),
|
||||
decltype(c_thread_mtx_),
|
||||
HoPerThreadSubC,
|
||||
WoPerThreadSubC>{};
|
||||
|
||||
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
|
||||
static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
|
||||
a_thread_copy_.Run(a_block_mtx,
|
||||
make_tuple(e_begin, k_begin),
|
||||
a_block_buf,
|
||||
a_thread_mtx_,
|
||||
make_tuple(I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
|
||||
static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
|
||||
threadwise_gemm.Run(a_thread_buf,
|
||||
make_tuple(I0, I0),
|
||||
b_thread_buf,
|
||||
make_tuple(e_begin, I0, h_begin, w_begin),
|
||||
c_thread_buf,
|
||||
make_tuple(k_begin, I0, h_begin, w_begin));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ABlockSliceMoveStepIdx>
|
||||
__device__ void MoveASliceWindow(const BlockMatrixA&,
|
||||
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
|
||||
{
|
||||
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
|
||||
}
|
||||
|
||||
private:
|
||||
MatrixIndex c_thread_begin_mtx_idx_;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,524 @@
|
||||
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "xdlops_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
class ABlockDesc,
|
||||
class BBlockDesc,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K1>
|
||||
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
|
||||
{
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||
|
||||
static constexpr index_t MWaves = M1 / MPerWave;
|
||||
static constexpr index_t NWaves = N1 / NPerWave;
|
||||
|
||||
static constexpr index_t MRepeat = M0;
|
||||
static constexpr index_t NRepeat = N0;
|
||||
|
||||
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
|
||||
|
||||
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
|
||||
|
||||
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static CIndex
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
|
||||
const index_t waveId = get_thread_local_1d_id() / WaveSize;
|
||||
|
||||
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
|
||||
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
|
||||
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
|
||||
|
||||
return CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
|
||||
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
|
||||
b_thread_copy_{CalculateBThreadOriginDataIndex()}
|
||||
{
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
|
||||
"wrong! K1 dimension not consistent");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
|
||||
|
||||
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
|
||||
|
||||
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
|
||||
|
||||
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
|
||||
// read A
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
|
||||
|
||||
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
|
||||
});
|
||||
|
||||
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
|
||||
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
m0,
|
||||
n0>(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
ABlockDesc,
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, MRepeat, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
K1,
|
||||
1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
BBlockDesc,
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, NRepeat, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
K1,
|
||||
1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
class ABlockDesc,
|
||||
class BBlockDesc,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K1>
|
||||
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
|
||||
{
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
|
||||
|
||||
static constexpr index_t WaveSize = 64;
|
||||
|
||||
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
|
||||
static constexpr index_t MWaves = M1 / MPerWave;
|
||||
static constexpr index_t NWaves = N1 / NPerWave;
|
||||
|
||||
static constexpr index_t MRepeat = M0;
|
||||
static constexpr index_t NRepeat = N0;
|
||||
|
||||
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
|
||||
|
||||
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
|
||||
|
||||
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t m_offset = waveId_m * MPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, m_offset, 0);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
const index_t waveId = thread_id / WaveSize;
|
||||
const index_t laneId = thread_id % WaveSize;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
|
||||
if constexpr(xdlops_gemm.IsKReduction)
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
|
||||
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t n_offset = waveId_n * NPerWave + laneId;
|
||||
const index_t k_offset = 0;
|
||||
return make_tuple(k_offset, 0, n_offset, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static CIndex
|
||||
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
|
||||
const index_t waveId = get_thread_local_1d_id() / WaveSize;
|
||||
|
||||
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
|
||||
|
||||
const index_t waveId_m = waveId / NWaves;
|
||||
const index_t waveId_n = waveId % NWaves;
|
||||
|
||||
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
|
||||
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
|
||||
|
||||
return CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
|
||||
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
|
||||
b_thread_copy_{CalculateBThreadOriginDataIndex()}
|
||||
{
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
|
||||
"wrong! K1 dimension not consistent");
|
||||
|
||||
static_assert(BlockSize == MWaves * NWaves * WaveSize,
|
||||
"BlockSize != MWaves * NWaves * WaveSize\n");
|
||||
|
||||
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
|
||||
|
||||
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
|
||||
__device__ void Run(const ABlockBuffer& a_block_buf,
|
||||
const BBlockBuffer& b_block_buf,
|
||||
CThreadBuffer& c_thread_buf) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
|
||||
// read A_sub_0
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(k, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy_.Run(BBlockDesc{},
|
||||
make_tuple(k, I1, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
b_thread_buf);
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy_.Run(ABlockDesc{},
|
||||
make_tuple(k, I1, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I1, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
0,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
});
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
0>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
xdlops_gemm.template Run<decltype(a_thread_desc_),
|
||||
decltype(b_thread_desc_),
|
||||
decltype(c_thread_desc_),
|
||||
1,
|
||||
1>(a_thread_buf, b_thread_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
private:
|
||||
// A[K, M]
|
||||
static constexpr auto a_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
// B[K, N]
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
|
||||
|
||||
static constexpr auto c_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
ABlockDesc,
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
1, // K1,
|
||||
1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
|
||||
FloatAB,
|
||||
BBlockDesc,
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<1, 1, 1, K1>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
1, // K1,
|
||||
1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
BThreadCopy b_thread_copy_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,170 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t DstScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v4
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin)
|
||||
: threadwise_transfer_(
|
||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
|
||||
dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& step,
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(
|
||||
src_desc, step, src_move_slice_window_step_hack);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorDim,
|
||||
DstVectorDim,
|
||||
SrcScalarPerVector,
|
||||
DstScalarPerVector,
|
||||
SrcScalarStrideInVector,
|
||||
DstScalarStrideInVector,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,156 @@
|
||||
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "cluster_descriptor.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
template <index_t BlockSize,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
typename SrcVectorTensorLengths,
|
||||
typename DstVectorTensorLengths,
|
||||
typename SrcVectorTensorContiguousDimOrder,
|
||||
typename DstVectorTensorContiguousDimOrder,
|
||||
bool ThreadTransferSrcResetCoordinateAfterRun,
|
||||
bool ThreadTransferDstResetCoordinateAfterRun>
|
||||
struct BlockwiseTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin)
|
||||
: threadwise_transfer_(
|
||||
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
|
||||
|
||||
{
|
||||
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
|
||||
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
|
||||
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
|
||||
nDim == ThreadClusterLengths::Size() &&
|
||||
nDim == ThreadClusterArrangeOrder::Size() &&
|
||||
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
|
||||
"wrong! nDim not consistent");
|
||||
|
||||
static_assert(
|
||||
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
|
||||
"wrong! threads should be mapped to cover entire slicing window");
|
||||
|
||||
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
|
||||
"wrong! BlockSize too small");
|
||||
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(get_thread_local_1d_id()));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
|
||||
|
||||
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
|
||||
src_block_slice_origin + thread_data_idx_begin);
|
||||
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
|
||||
dst_block_slice_origin + thread_data_idx_begin);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
// SrcMoveSliceWindowStepHack to control index calculation move slice window
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& step,
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveSrcSliceWindow(
|
||||
src_desc, step, src_move_slice_window_step_hack);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
|
||||
{
|
||||
if(BlockSize == thread_cluster_desc_.GetElementSize() or
|
||||
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
|
||||
{
|
||||
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
using ThreadwiseTransfer =
|
||||
ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
|
||||
DstInMemOp,
|
||||
SrcData,
|
||||
DstData,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
SrcDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
SrcVectorTensorLengths,
|
||||
DstVectorTensorLengths,
|
||||
SrcVectorTensorContiguousDimOrder,
|
||||
DstVectorTensorContiguousDimOrder,
|
||||
ThreadTransferSrcResetCoordinateAfterRun,
|
||||
ThreadTransferDstResetCoordinateAfterRun>;
|
||||
|
||||
ThreadwiseTransfer threadwise_transfer_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,659 @@
|
||||
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
|
||||
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseContraction,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
|
||||
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
|
||||
typename CGridBlockCluster_BlockId_To_GM10_GN10,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_contraction_dlops_v1r2(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
index_t GM1PerBlockGM11,
|
||||
index_t GN1PerBlockGN11,
|
||||
index_t GK0PerBlock,
|
||||
index_t BM1PerThreadBM11,
|
||||
index_t BN1PerThreadBN11,
|
||||
index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs,
|
||||
typename BM10BN10ThreadClusterBN10Xs,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// GM0 and GN0 need to known at compile-time
|
||||
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
|
||||
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
|
||||
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// lds max alignment
|
||||
// TODO: part of them should be moved into blockwise-gemm
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = GK1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
|
||||
"wrong! GM0 and GN0 need to be known at compile-time");
|
||||
|
||||
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
|
||||
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (
|
||||
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
|
||||
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
|
||||
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
|
||||
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
|
||||
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
|
||||
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
|
||||
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
|
||||
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
|
||||
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
|
||||
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
|
||||
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr index_t GM11 = GM1PerBlockGM11;
|
||||
constexpr index_t GN11 = GN1PerBlockGN11;
|
||||
|
||||
const index_t GM10 = GM1 / GM11;
|
||||
const index_t GN10 = GN1 / GN11;
|
||||
|
||||
const index_t grid_size = GM10 * GN10;
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
|
||||
{
|
||||
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
|
||||
{
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
|
||||
|
||||
const auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
const auto GM10 = GM1 / GM11;
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
make_tuple(make_pass_through_transform(GK0),
|
||||
make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GK1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
|
||||
{
|
||||
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
|
||||
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
|
||||
|
||||
const auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
make_tuple(make_pass_through_transform(GK0),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11)),
|
||||
make_pass_through_transform(GK1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
|
||||
|
||||
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
constexpr auto BM = GM0 * GM11;
|
||||
constexpr auto BN = GN0 * GN11;
|
||||
|
||||
constexpr auto BM1 =
|
||||
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
|
||||
BM1PerThreadBM11>{};
|
||||
constexpr auto BN1 =
|
||||
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
|
||||
BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto BM0 = BM / BM1;
|
||||
constexpr auto BN0 = BN / BN1;
|
||||
|
||||
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
make_tuple(make_pass_through_transform(GM0),
|
||||
make_unmerge_transform(make_tuple(GM10, GM11)),
|
||||
make_pass_through_transform(GN0),
|
||||
make_unmerge_transform(make_tuple(GN10, GN11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
|
||||
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_merge_transform(make_tuple(GM0, GM11)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_merge_transform(make_tuple(GN0, GN11))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
|
||||
c_gm10_bm_gn10_bn_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GM10),
|
||||
make_unmerge_transform(make_tuple(BM0, BM1)),
|
||||
make_pass_through_transform(GN10),
|
||||
make_unmerge_transform(make_tuple(BN0, BN1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
|
||||
|
||||
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
|
||||
{
|
||||
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
|
||||
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
|
||||
|
||||
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
|
||||
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
|
||||
|
||||
const auto GM10 = GM1 / GM11;
|
||||
const auto GN10 = GN1 / GN11;
|
||||
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_grid_block_cluster_blockid_to_gm10_gn10;
|
||||
}
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
|
||||
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
|
||||
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
|
||||
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
|
||||
|
||||
// divide block work by [GM10, GN10]
|
||||
const auto c_gm10_gn10_block_cluster_idx =
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
|
||||
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
|
||||
|
||||
// lds max alignment
|
||||
// TODO: part of them should be moved into blockwise-gemm
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = GK1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
|
||||
max_lds_align);
|
||||
|
||||
// A matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory for blockwise GEMM
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
|
||||
|
||||
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
|
||||
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
|
||||
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
|
||||
"wrong!");
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
|
||||
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
make_multi_index(0, 0, igm10, 0, 0),
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
make_multi_index(0, 0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
|
||||
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3, 4>,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
make_multi_index(0, 0, ign10, 0, 0),
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
make_multi_index(0, 0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
|
||||
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
|
||||
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_block_desc_gk0_bm_gk1),
|
||||
decltype(b_block_desc_gk0_bn_gk1),
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11>{};
|
||||
|
||||
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
|
||||
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
|
||||
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t gk0_block_on_grid = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
|
||||
|
||||
gk0_block_on_grid += 2 * GK0PerBlock;
|
||||
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
|
||||
I1,
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
|
||||
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
|
||||
|
||||
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
|
||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
|
||||
Sequence<1,
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
|
||||
1,
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
|
||||
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
make_multi_index(igm10,
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
|
||||
ign10,
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
|
||||
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
|
||||
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,662 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
|
||||
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r2.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AKM0M1GridDesc,
|
||||
typename BKN0N1GridDesc,
|
||||
typename CM0M10M11N0N10N11GridDesc,
|
||||
typename CBlockIdToM0N0BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dlops_v1r2(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AKM0M1GridDesc a_k_m0_m1_grid_desc,
|
||||
const BKN0N1GridDesc b_k_n0_n1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
// pass tensor descriptor by CONSTANT void pointer
|
||||
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
|
||||
// non-modifiable parameter address space, so compiler can enable corresponding optimization
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AKM0M1GridDesc,
|
||||
typename BKN0N1GridDesc,
|
||||
typename CM0M10M11N0N10N11GridDesc,
|
||||
typename CBlockIdToM0N0BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
||||
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
// first cast void CONSTANT void* to void*
|
||||
// second cast void* to Desc*
|
||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||
const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
|
||||
const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AKMGridDesc,
|
||||
typename BKNGridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlockM1,
|
||||
index_t NPerBlockN1,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThreadM111,
|
||||
index_t N1PerThreadN111,
|
||||
index_t KPerThread,
|
||||
index_t M11N11ThreadClusterM1100,
|
||||
index_t M11N11ThreadClusterN1100,
|
||||
index_t M11N11ThreadClusterM1101,
|
||||
index_t M11N11ThreadClusterN1101,
|
||||
typename ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
typename ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_M1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
typename BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_N1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
|
||||
Number<BBlockTransferDstScalarPerVector_N1>{},
|
||||
Number<M1PerThreadM111>{},
|
||||
Number<N1PerThreadN111>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
|
||||
const BKNGridDesc& b_k_n_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K == b_k_n_grid_desc.GetLength(I0)) &&
|
||||
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
|
||||
{
|
||||
const auto K = a_k_m_grid_desc.GetLength(I0);
|
||||
const auto M = a_k_m_grid_desc.GetLength(I1);
|
||||
|
||||
const auto M1 = Number<MPerBlockM1>{};
|
||||
const auto M0 = M / M1;
|
||||
|
||||
const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
|
||||
a_k_m_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
return a_k_m0_m1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
|
||||
{
|
||||
const auto K = b_k_n_grid_desc.GetLength(I0);
|
||||
const auto N = b_k_n_grid_desc.GetLength(I1);
|
||||
|
||||
const auto N1 = Number<NPerBlockN1>{};
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
|
||||
b_k_n_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
|
||||
|
||||
return b_k_n0_n1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
constexpr auto M11 =
|
||||
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
|
||||
constexpr auto N11 =
|
||||
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
|
||||
|
||||
constexpr auto M10 = M1 / M11;
|
||||
constexpr auto N10 = N1 / N11;
|
||||
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
|
||||
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
|
||||
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto c_m0_n0_block_cluster_idx =
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
|
||||
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
|
||||
Number<BBlockTransferDstScalarPerVector_N1>{},
|
||||
Number<M1PerThreadM111>{},
|
||||
Number<N1PerThreadN111>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, MPerBlockM1>,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k_m0_m1_grid_desc),
|
||||
decltype(a_k_m0_m1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_k_m0_m1_grid_desc,
|
||||
make_multi_index(0, im0, 0),
|
||||
a_k_m0_m1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, NPerBlockN1>,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k_n0_n1_grid_desc),
|
||||
decltype(b_k_n0_n1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_k_n0_n1_grid_desc,
|
||||
make_multi_index(0, in0, 0),
|
||||
b_k_n0_n1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
|
||||
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
|
||||
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k_m_block_desc),
|
||||
decltype(b_k_n_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM1100,
|
||||
M11N11ThreadClusterN1100,
|
||||
M11N11ThreadClusterM1101,
|
||||
M11N11ThreadClusterN1101,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size =
|
||||
math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size =
|
||||
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
|
||||
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
|
||||
AGridMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
|
||||
BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_k_m0_m1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_k_n0_n1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < K - 2 * KPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k_m0_m1_global_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k_n0_n1_global_move_slice_window_step_hack);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(
|
||||
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
|
||||
b_blockwise_copy.RunRead(
|
||||
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
||||
I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
|
||||
|
||||
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
|
||||
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
|
||||
Sequence<1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
|
||||
1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
make_multi_index(im0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
|
||||
in0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,650 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
|
||||
#define CK_GRIDWISE_GEMM_V1R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_dlops_v2r3.hpp"
|
||||
#include "blockwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_transfer_v2.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0M0M1K1GridDesc,
|
||||
typename BK0N0N1K1GridDesc,
|
||||
typename CM0M10M11N0N10N11GridDesc,
|
||||
typename CBlockIdToM0N0BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dlops_v1r3(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
|
||||
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
// pass tensor descriptor by CONSTANT void pointer
|
||||
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
|
||||
// non-modifiable parameter address space, so compiler can enable corresponding optimization
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0M0M1K1GridDesc,
|
||||
typename BK0N0N1K1GridDesc,
|
||||
typename CM0M10M11N0N10N11GridDesc,
|
||||
typename CBlockIdToM0N0BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
// first cast void CONSTANT void* to void*
|
||||
// second cast void* to Desc*
|
||||
// the copy constructor of tensor descriptor doesn't take address_space(4)
|
||||
const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
|
||||
const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m0_m1_k1_grid_desc,
|
||||
b_k0_n0_n1_k1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlockM1,
|
||||
index_t NPerBlockN1,
|
||||
index_t KPerBlock,
|
||||
index_t M1PerThreadM111,
|
||||
index_t N1PerThreadN111,
|
||||
index_t KPerThread,
|
||||
typename M11N11ThreadClusterM110Xs,
|
||||
typename M11N11ThreadClusterN110Xs,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v1r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size =
|
||||
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size =
|
||||
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
|
||||
|
||||
return has_main_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
|
||||
{
|
||||
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
|
||||
|
||||
return has_double_tail_k_block_loop;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
|
||||
{
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
|
||||
const auto M1 = Number<MPerBlockM1>{};
|
||||
const auto M0 = M / M1;
|
||||
|
||||
const auto a_k0_m0_m1_k1_grid_desc =
|
||||
transform_tensor_descriptor(a_k0_m_k1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(M0, M1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return a_k0_m0_m1_k1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
|
||||
{
|
||||
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
|
||||
const auto N1 = Number<NPerBlockN1>{};
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto b_k0_n0_n1_k1_grid_desc =
|
||||
transform_tensor_descriptor(b_k0_n_k1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K0),
|
||||
make_unmerge_transform(make_tuple(N0, N1)),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
return b_k0_n0_n1_k1_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
constexpr auto M11 =
|
||||
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
|
||||
M1PerThreadM111>{};
|
||||
constexpr auto N11 =
|
||||
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto M10 = M1 / M11;
|
||||
constexpr auto N10 = N1 / N11;
|
||||
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
|
||||
make_unmerge_transform(make_tuple(N0, N10, N11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
|
||||
|
||||
return c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlockM1>{};
|
||||
constexpr auto N1 = Number<NPerBlockN1>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
|
||||
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ static void
|
||||
Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
|
||||
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
|
||||
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>)
|
||||
{
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto c_m0_n0_block_cluster_idx =
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
|
||||
make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force index data into SGPR
|
||||
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
|
||||
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
|
||||
|
||||
// TODO: change this. I think it needs multi-dimensional alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// A matrix in LDS memory, for blockwise GEMM
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
|
||||
|
||||
// TODO: check alignment
|
||||
// B matrix in LDS memory, for blockwise GEMM
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
|
||||
|
||||
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
|
||||
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
|
||||
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
|
||||
"wrong!");
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
|
||||
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_grid_desc),
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
|
||||
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(a_k0_m0_m1_k1_grid_desc,
|
||||
make_multi_index(0, im0, 0, 0),
|
||||
a_k0_m0_m1_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
|
||||
BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
|
||||
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n0_n1_k1_grid_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
|
||||
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
|
||||
false,
|
||||
true>(b_k0_n0_n1_k1_grid_desc,
|
||||
make_multi_index(0, in0, 0, 0),
|
||||
b_k0_n0_n1_k1_block_desc,
|
||||
make_multi_index(0, 0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
|
||||
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
|
||||
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
|
||||
// register
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111,
|
||||
KPerThread,
|
||||
M11N11ThreadClusterM110Xs,
|
||||
M11N11ThreadClusterN110Xs,
|
||||
M1PerThreadM111,
|
||||
N1PerThreadN111>{};
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
|
||||
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
|
||||
|
||||
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
|
||||
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block_double = p_shared_block;
|
||||
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
|
||||
|
||||
// register allocation for output
|
||||
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
|
||||
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
|
||||
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_m10_m11_n10_n11_thread_desc),
|
||||
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
|
||||
.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
FloatAcc{0});
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
|
||||
|
||||
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block_double + a_block_aligned_space_size,
|
||||
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block_double + b_block_aligned_space_size,
|
||||
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
}
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
|
||||
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
|
||||
a_block_even_buf,
|
||||
b_block_even_buf,
|
||||
c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
|
||||
|
||||
// odd iteration
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS doubel buffer: load next data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
|
||||
|
||||
k_block_data_begin += 2 * KPerBlock;
|
||||
} while(k_block_data_begin < K0 - 2 * KPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(
|
||||
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
|
||||
b_blockwise_copy.MoveSrcSliceWindow(
|
||||
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load last data from device mem
|
||||
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
|
||||
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
|
||||
// LDS double buffer: store last data to LDS
|
||||
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(
|
||||
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
|
||||
I1,
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
|
||||
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
|
||||
|
||||
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
|
||||
get_thread_local_1d_id());
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
|
||||
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
|
||||
Sequence<1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
|
||||
1,
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
|
||||
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
make_multi_index(im0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
|
||||
in0,
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
|
||||
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
|
||||
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_grid_buf,
|
||||
CGridStepHacks{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,458 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_V2_HPP
|
||||
#define CK_GRIDWISE_GEMM_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "blockwise_gemm_dlops_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AGlobalDesc,
|
||||
typename BGlobalDesc,
|
||||
typename CGlobalDesc,
|
||||
index_t KPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGlobalStepHacks,
|
||||
typename BGlobalStepHacks,
|
||||
typename CGlobalStepHacks,
|
||||
typename AGlobalMoveSliceWindowStepHacks,
|
||||
typename BGlobalMoveSliceWindowStepHacks>
|
||||
struct GridwiseGemmDlops_km_kn_mn_v3
|
||||
{
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto E = EPerBlock * 3 * 3;
|
||||
|
||||
constexpr auto max_lds_align =
|
||||
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return a_block_space_size * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc& b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc& c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
|
||||
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
|
||||
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
|
||||
|
||||
constexpr auto E = EPerBlock * 3 * 3;
|
||||
|
||||
// const auto E = a_e_k_global_desc.GetLength(I0);
|
||||
const auto K = a_e_k_global_desc.GetLength(I1);
|
||||
|
||||
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
|
||||
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
// divide block work by [M, N]
|
||||
#if 0
|
||||
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
|
||||
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
|
||||
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
|
||||
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
|
||||
|
||||
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
|
||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||
#else
|
||||
// Hack: this force result into SGPR
|
||||
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
|
||||
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
|
||||
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
|
||||
|
||||
const index_t k_block_work_id =
|
||||
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
|
||||
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
|
||||
|
||||
const index_t ho_block_work_id =
|
||||
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
|
||||
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
|
||||
#endif
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align =
|
||||
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO:: more elegent way of defining c_thread_mtx
|
||||
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
decltype(a_e_k_block_desc),
|
||||
decltype(b_e_n_ho_wo_block_desc),
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K>{};
|
||||
|
||||
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const auto k_thread_id = c_thread_mtx_index.k;
|
||||
const auto ho_thread_id = c_thread_mtx_index.h;
|
||||
const auto wo_thread_id = c_thread_mtx_index.w;
|
||||
|
||||
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
|
||||
|
||||
const index_t ho_thread_data_on_global =
|
||||
ho_block_data_on_global + ho_thread_id * HoPerThread;
|
||||
const index_t wo_thread_data_on_global =
|
||||
wo_block_data_on_global + wo_thread_id * WoPerThread;
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<E, KPerBlock>,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_e_k_global_desc),
|
||||
decltype(a_e_k_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<0, 1>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_e_k_global_desc,
|
||||
make_multi_index(0, k_block_data_on_global),
|
||||
a_e_k_desc,
|
||||
make_multi_index(0, 0));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
|
||||
|
||||
auto b_threadwise_transfer =
|
||||
ThreadwiseTensorSliceTransfer_v2<FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_e_n_ho_wo_global_desc),
|
||||
decltype(b_e_n_ho_wo_thread_desc),
|
||||
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
1,
|
||||
true>(
|
||||
b_e_n_ho_wo_global_desc,
|
||||
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_shared_block, a_e_k_desc.GetElementSpaceSize());
|
||||
|
||||
// register allocation for output
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
FloatAcc,
|
||||
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
|
||||
// initialize output thread tensor
|
||||
ThreadwiseTensorSliceSet_v1<FloatAcc,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
|
||||
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
|
||||
|
||||
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
BGlobalMoveSliceWindowStepHacks{};
|
||||
|
||||
// double regsiter buffer for b
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
FloatAB,
|
||||
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
b_thread_even_buf, b_thread_odd_buf;
|
||||
|
||||
// LDS double buffer: preload data
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if constexpr(HasMainKBlockLoop)
|
||||
{
|
||||
index_t e_block_data_begin = 0;
|
||||
|
||||
// LDS double buffer: main body
|
||||
// use Do-While loop instead of For loop to simplify control flow
|
||||
do
|
||||
{
|
||||
// even iteration
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_even_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
e_block_data_begin += 2 * EPerBlock;
|
||||
|
||||
} while(e_block_data_begin < E - 2 * EPerBlock);
|
||||
}
|
||||
|
||||
// LDS double buffer: tail
|
||||
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
|
||||
{
|
||||
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
|
||||
b_thread_slice_copy_step);
|
||||
|
||||
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
|
||||
b_global_buf,
|
||||
b_e_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
b_thread_odd_buf,
|
||||
b_e_n_ho_wo_global_step_hacks);
|
||||
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
|
||||
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
// output: register to global memory
|
||||
{
|
||||
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + k_thread_id * KPerThread;
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
|
||||
FloatC,
|
||||
decltype(c_k_n_ho_wo_thread_desc),
|
||||
decltype(c_k_n_ho_wo_global_desc),
|
||||
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>(
|
||||
c_k_n_ho_wo_global_desc,
|
||||
make_multi_index(
|
||||
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
|
||||
.Run(c_k_n_ho_wo_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
c_thread_buf,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
c_global_buf,
|
||||
c_k_n_ho_wo_global_tensor_step_hacks);
|
||||
}
|
||||
}
|
||||
|
||||
// pass tensor descriptor by reference
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc& b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc& c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
p_shared_block,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by their pointers
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
|
||||
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
|
||||
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
|
||||
// pass tensor descriptors by void*
|
||||
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
|
||||
__device__ void Run(const void* p_a_e_k_global_desc,
|
||||
const FloatAB* __restrict__ p_a_global,
|
||||
const void* p_b_e_n_ho_wo_global_desc,
|
||||
const FloatAB* __restrict__ p_b_global,
|
||||
const void* p_c_k_n_ho_wo_global_desc,
|
||||
FloatC* __restrict__ p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>,
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>) const
|
||||
{
|
||||
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
|
||||
const auto b_e_n_ho_wo_global_desc =
|
||||
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
|
||||
const auto c_k_n_ho_wo_global_desc =
|
||||
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
|
||||
|
||||
Run(a_e_k_global_desc,
|
||||
p_a_global,
|
||||
b_e_n_ho_wo_global_desc,
|
||||
p_b_global,
|
||||
c_k_n_ho_wo_global_desc,
|
||||
p_c_global,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,799 @@
|
||||
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
|
||||
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "multi_index_transform_helper.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_xdlops.hpp"
|
||||
#include "blockwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0M1M2NGridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
|
||||
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
|
||||
const CBlockClusterAdaptor c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CM0M1M2NGridDesc,
|
||||
typename CBlockClusterAdaptor>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void CONSTANT* p_c_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
|
||||
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
|
||||
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
|
||||
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
|
||||
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
|
||||
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_block_cluster_adaptor);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AK0MK1GridDesc,
|
||||
typename BK0NK1GridDesc,
|
||||
typename CMNGridDesc,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t K1Value,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks,
|
||||
bool CAccessOrderMRepeatNRepeat>
|
||||
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
constexpr auto b_block_space_size =
|
||||
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
// TODO: turn on this
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
|
||||
"wrong! K1 need to be known at compile-time");
|
||||
|
||||
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
|
||||
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
|
||||
|
||||
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
|
||||
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
|
||||
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
|
||||
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
|
||||
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
|
||||
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t
|
||||
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
|
||||
|
||||
return grid_size;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
|
||||
|
||||
constexpr auto CLayout = xdlops_gemm.GetCLayout();
|
||||
|
||||
constexpr auto M0 = Number<CLayout.M1()>{};
|
||||
constexpr auto M1 = Number<CLayout.N1()>{};
|
||||
constexpr auto M2 = Number<CLayout.M0()>{};
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||
|
||||
constexpr auto N1 = Number<CLayout.N0()>{};
|
||||
|
||||
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
|
||||
c_m_n_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
|
||||
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
|
||||
|
||||
return c_m0_m1_m2_n_grid_desc;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
|
||||
{
|
||||
const auto M = c_m_n_grid_desc.GetLength(I0);
|
||||
const auto N = c_m_n_grid_desc.GetLength(I1);
|
||||
|
||||
constexpr auto M1 = Number<MPerBlock>{};
|
||||
constexpr auto N1 = Number<NPerBlock>{};
|
||||
|
||||
const auto M0 = M / M1;
|
||||
const auto N0 = N / N1;
|
||||
|
||||
#if 1
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
|
||||
make_tuple(Sequence<0, 1>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#elif 1
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
|
||||
make_tuple(Sequence<1, 0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
#endif
|
||||
|
||||
return c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
|
||||
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
|
||||
|
||||
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
FloatAB* __restrict__ p_shared_block,
|
||||
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
|
||||
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
|
||||
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
|
||||
const CBlockClusterAdaptor& c_block_cluster_adaptor)
|
||||
{
|
||||
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
|
||||
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
|
||||
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
|
||||
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
|
||||
|
||||
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_work_idx =
|
||||
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
// HACK: this force m/n_block_data_idx_on_grid into SGPR
|
||||
const index_t m_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
|
||||
|
||||
const index_t n_block_data_idx_on_grid =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
|
||||
|
||||
// lds max alignment
|
||||
constexpr auto max_lds_align = K1;
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
|
||||
// A matrix blockwise copy
|
||||
auto a_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(a_k0_m_k1_grid_desc),
|
||||
decltype(a_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(a_k0_m_k1_grid_desc,
|
||||
make_multi_index(0, m_block_data_idx_on_grid, 0),
|
||||
a_k0_m_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// B matrix blockwise copy
|
||||
auto b_blockwise_copy =
|
||||
BlockwiseTensorSliceTransfer_v4<BlockSize,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
Sequence<KPerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
FloatAB,
|
||||
FloatAB,
|
||||
decltype(b_k0_n_k1_grid_desc),
|
||||
decltype(b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
Sequence<1, 0, 2>,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
1,
|
||||
1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
true>(b_k0_n_k1_grid_desc,
|
||||
make_multi_index(0, n_block_data_idx_on_grid, 0),
|
||||
b_k0_n_k1_block_desc,
|
||||
make_multi_index(0, 0, 0));
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[KPerBlock, MPerBlock] is in LDS
|
||||
// b_mtx[KPerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
|
||||
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
|
||||
NPerBlock % (NPerWave * NRepeat) == 0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
|
||||
a_k0_m_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
|
||||
b_k0_n_k1_block_desc,
|
||||
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
|
||||
make_pass_through_transform(K1)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
|
||||
FloatAB,
|
||||
decltype(a_k0_m0_m1_k1_block_desc),
|
||||
decltype(b_k0_n0_n1_k1_block_desc),
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1>{};
|
||||
|
||||
constexpr auto CLayout = blockwise_gemm.GetCLayout();
|
||||
|
||||
constexpr index_t BlkSize = CLayout.GetBlkSize();
|
||||
constexpr index_t NumBlks = CLayout.GetNumBlks();
|
||||
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
|
||||
|
||||
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
|
||||
|
||||
constexpr auto c_mr_nr_blk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr,
|
||||
vector_type<FloatAcc, BlkSize>,
|
||||
c_mr_nr_blk_desc.GetElementSpaceSize(),
|
||||
true>
|
||||
c_thread_buf;
|
||||
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size =
|
||||
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
FloatAB* p_a_block = p_shared_block;
|
||||
FloatAB* p_b_block = p_shared_block + a_block_space_size;
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
|
||||
|
||||
// hack to control index calculation when iterating over A and B matrix for threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
|
||||
|
||||
// hack to control index calculation when move slice window for A and B matrix for
|
||||
// threadwise copy
|
||||
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
|
||||
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
|
||||
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
|
||||
|
||||
// preload data into LDS
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
}
|
||||
|
||||
// main body
|
||||
index_t k_block_data_begin = 0;
|
||||
|
||||
do
|
||||
{
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
|
||||
a_block_slice_copy_step,
|
||||
a_k0_m_k1_grid_move_slice_window_step_hack);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
|
||||
b_block_slice_copy_step,
|
||||
b_k0_n_k1_grid_move_slice_window_step_hack);
|
||||
|
||||
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
|
||||
|
||||
k_block_data_begin += KPerBlock;
|
||||
} while(k_block_data_begin < (K0 - KPerBlock));
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
|
||||
}
|
||||
|
||||
#if 0
|
||||
// output: register to global memory
|
||||
{
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
constexpr index_t N0 = CLayout.N1();
|
||||
constexpr index_t N1 = CLayout.N0();
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
|
||||
Number<NRepeat>{},
|
||||
Number<1>{},
|
||||
Number<1>{},
|
||||
Number<M0>{},
|
||||
Number<1>{},
|
||||
Number<M2>{},
|
||||
Number<1>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
|
||||
c_blk_buf_;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
|
||||
constexpr auto blk_off =
|
||||
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
|
||||
|
||||
static_for<0, BlkSize, 1>{}([&](auto j) {
|
||||
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
|
||||
c_thread_buf[Number<blk_off>{}]
|
||||
.template AsType<FloatAcc>()[Number<j>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
|
||||
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
|
||||
|
||||
ThreadwiseTensorSliceTransfer_v1r3<
|
||||
FloatC,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_grid_desc),
|
||||
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
|
||||
n_thread_data_on_grid / (N1 * NWaves),
|
||||
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
|
||||
n_thread_data_on_grid % (N1 * NWaves) / N1,
|
||||
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
|
||||
m_thread_data_on_grid % (M2 * M1) / M2,
|
||||
m_thread_data_on_grid % M2,
|
||||
n_thread_data_on_grid % N1)}
|
||||
.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_blk_buf_,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
}
|
||||
#else
|
||||
{
|
||||
constexpr index_t M0 = CLayout.M1();
|
||||
constexpr index_t M1 = CLayout.N1();
|
||||
constexpr index_t M2 = CLayout.M0();
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
|
||||
|
||||
const index_t m_thread_data_on_grid =
|
||||
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
|
||||
|
||||
const index_t n_thread_data_on_grid =
|
||||
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
|
||||
|
||||
auto c_thread_copy =
|
||||
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
|
||||
FloatC,
|
||||
decltype(c_m0_m1_m2_n_thread_desc),
|
||||
decltype(c_m0_m1_m2_n_grid_desc),
|
||||
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
CGlobalMemoryDataOperation,
|
||||
1,
|
||||
true>{
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
make_multi_index(0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
m_thread_data_on_grid / (M2 * M1),
|
||||
m_thread_data_on_grid % (M2 * M1) / M2,
|
||||
m_thread_data_on_grid % M2,
|
||||
n_thread_data_on_grid)};
|
||||
|
||||
auto init_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
|
||||
return c_thread_idx_;
|
||||
};
|
||||
|
||||
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
|
||||
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
|
||||
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus);
|
||||
|
||||
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
|
||||
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_grid_buf,
|
||||
c_m0_m1_m2_n_grid_tensor_step_hacks);
|
||||
};
|
||||
|
||||
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
|
||||
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
|
||||
(MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
|
||||
(MRepeat == 1 && NRepeat == 1),
|
||||
"wrong");
|
||||
|
||||
if constexpr(MRepeat == 4 && NRepeat == 4)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
nrepeat_plus_copy(make_tuple(I0, I3));
|
||||
mrepeat_plus_copy(make_tuple(I1, I3));
|
||||
nrepeat_minus_copy(make_tuple(I1, I2));
|
||||
nrepeat_minus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
nrepeat_plus_copy(make_tuple(I2, I1));
|
||||
nrepeat_plus_copy(make_tuple(I2, I2));
|
||||
nrepeat_plus_copy(make_tuple(I2, I3));
|
||||
mrepeat_plus_copy(make_tuple(I3, I3));
|
||||
nrepeat_minus_copy(make_tuple(I3, I2));
|
||||
nrepeat_minus_copy(make_tuple(I3, I1));
|
||||
nrepeat_minus_copy(make_tuple(I3, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
mrepeat_plus_copy(make_tuple(I3, I0));
|
||||
nrepeat_plus_copy(make_tuple(I3, I1));
|
||||
mrepeat_minus_copy(make_tuple(I2, I1));
|
||||
mrepeat_minus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
mrepeat_plus_copy(make_tuple(I1, I2));
|
||||
mrepeat_plus_copy(make_tuple(I2, I2));
|
||||
mrepeat_plus_copy(make_tuple(I3, I2));
|
||||
nrepeat_plus_copy(make_tuple(I3, I3));
|
||||
mrepeat_minus_copy(make_tuple(I2, I3));
|
||||
mrepeat_minus_copy(make_tuple(I1, I3));
|
||||
mrepeat_minus_copy(make_tuple(I0, I3));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 4 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
mrepeat_plus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
nrepeat_plus_copy(make_tuple(I2, I1));
|
||||
mrepeat_plus_copy(make_tuple(I3, I1));
|
||||
nrepeat_minus_copy(make_tuple(I3, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
mrepeat_plus_copy(make_tuple(I2, I0));
|
||||
mrepeat_plus_copy(make_tuple(I3, I0));
|
||||
nrepeat_plus_copy(make_tuple(I3, I1));
|
||||
mrepeat_minus_copy(make_tuple(I2, I1));
|
||||
mrepeat_minus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 4)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
nrepeat_plus_copy(make_tuple(I0, I3));
|
||||
mrepeat_plus_copy(make_tuple(I1, I3));
|
||||
nrepeat_minus_copy(make_tuple(I1, I2));
|
||||
nrepeat_minus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
nrepeat_plus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
nrepeat_plus_copy(make_tuple(I0, I2));
|
||||
mrepeat_plus_copy(make_tuple(I1, I2));
|
||||
nrepeat_plus_copy(make_tuple(I1, I3));
|
||||
mrepeat_minus_copy(make_tuple(I0, I3));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
|
||||
if constexpr(CAccessOrderMRepeatNRepeat)
|
||||
{
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
mrepeat_plus_copy(make_tuple(I1, I1));
|
||||
nrepeat_minus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else
|
||||
{
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
nrepeat_plus_copy(make_tuple(I1, I1));
|
||||
mrepeat_minus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
}
|
||||
else if constexpr(MRepeat == 2 && NRepeat == 1)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
mrepeat_plus_copy(make_tuple(I1, I0));
|
||||
}
|
||||
else if constexpr(MRepeat == 1 && NRepeat == 2)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
nrepeat_plus_copy(make_tuple(I0, I1));
|
||||
}
|
||||
else if constexpr(MRepeat == 1 && NRepeat == 1)
|
||||
{
|
||||
init_copy(make_tuple(I0, I0));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}; // namespace ck
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,229 @@
|
||||
#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
|
||||
#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "math.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
|
||||
// Tensor element can be vectorized data
|
||||
// Assume:
|
||||
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
|
||||
// known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AThreadDesc_TK0_TM0_TM1_TK1,
|
||||
typename BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
typename CThreadDesc_TM0_TM1_TN0_TN1,
|
||||
typename TKLengths,
|
||||
typename TMLengths,
|
||||
typename TNLengths,
|
||||
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
|
||||
{
|
||||
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
|
||||
{
|
||||
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
|
||||
|
||||
// TODO remove this restriction
|
||||
static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
template <typename ABuffer,
|
||||
typename AOriginIdx,
|
||||
typename BBuffer,
|
||||
typename BOriginIdx,
|
||||
typename CBuffer,
|
||||
typename COriginIdx>
|
||||
__device__ static void Run(const ABuffer& a_buf,
|
||||
AOriginIdx,
|
||||
const BBuffer& b_buf,
|
||||
BOriginIdx,
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto TK = TKLengths{}[I0];
|
||||
constexpr auto TM0 = TMLengths{}[I0];
|
||||
constexpr auto TM1 = TMLengths{}[I1];
|
||||
constexpr auto TN0 = TNLengths{}[I0];
|
||||
constexpr auto TN1 = TNLengths{}[I1];
|
||||
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, TK, 1>{}([&](auto tk) {
|
||||
static_for<0, TM0, 1>{}([&](auto tm0) {
|
||||
static_for<0, TM1, 1>{}([&](auto tm1) {
|
||||
static_for<0, TN0, 1>{}([&](auto tn0) {
|
||||
static_for<0, TN1, 1>{}([&](auto tn1) {
|
||||
constexpr index_t a_offset =
|
||||
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk, tm0, tm1));
|
||||
constexpr index_t b_offset =
|
||||
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk, tn0, tn1));
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
|
||||
// Tensor element can be vectorized data
|
||||
// Assume:
|
||||
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
|
||||
// known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AThreadDesc_TK0_TM0_TM1_TK1,
|
||||
typename BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
typename CThreadDesc_TM0_TM1_TN0_TN1,
|
||||
typename TKLengths,
|
||||
typename TMLengths,
|
||||
typename TNLengths,
|
||||
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
|
||||
{
|
||||
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
|
||||
{
|
||||
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
|
||||
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
|
||||
|
||||
// TODO remove this restriction
|
||||
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
template <typename ABuffer,
|
||||
typename AOriginIdx,
|
||||
typename BBuffer,
|
||||
typename BOriginIdx,
|
||||
typename CBuffer,
|
||||
typename COriginIdx>
|
||||
__device__ static void Run(const ABuffer& a_buf,
|
||||
AOriginIdx,
|
||||
const BBuffer& b_buf,
|
||||
BOriginIdx,
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t TK0 = TKLengths{}[I0];
|
||||
constexpr index_t TK1 = TKLengths{}[I1];
|
||||
constexpr index_t TM0 = TMLengths{}[I0];
|
||||
constexpr index_t TM1 = TMLengths{}[I1];
|
||||
constexpr index_t TN0 = TNLengths{}[I0];
|
||||
constexpr index_t TN1 = TNLengths{}[I1];
|
||||
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, TK0, 1>{}([&](auto tk0) {
|
||||
static_for<0, TM0, 1>{}([&](auto tm0) {
|
||||
static_for<0, TM1, 1>{}([&](auto tm1) {
|
||||
static_for<0, TN0, 1>{}([&](auto tn0) {
|
||||
static_for<0, TN1, 1>{}([&](auto tn1) {
|
||||
vector_type<FloatA, TK1> a_vec;
|
||||
vector_type<FloatB, TK1> b_vec;
|
||||
|
||||
static_for<0, TK1, 1>{}([&](auto tk1) {
|
||||
constexpr index_t a_offset =
|
||||
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
|
||||
|
||||
constexpr index_t b_offset =
|
||||
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
|
||||
|
||||
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
|
||||
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
|
||||
});
|
||||
|
||||
using a_vector_t = typename vector_type<FloatA, TK1>::type;
|
||||
using b_vector_t = typename vector_type<FloatB, TK1>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
inner_product<a_vector_t, b_vector_t, FloatC>(
|
||||
a_vec.template AsType<a_vector_t>()[I0],
|
||||
b_vec.template AsType<b_vector_t>()[I0],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,160 @@
|
||||
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
|
||||
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "math.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// C[M, N] += transpose(A[K, M]) * B[K, N]
|
||||
// Element of matrix can be vectorized data
|
||||
// Assume:
|
||||
// 1. ADesc, BDesc, CDesc are known at compile-time
|
||||
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
|
||||
template <typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc,
|
||||
index_t H,
|
||||
index_t W,
|
||||
typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemmDlops_km_kn_mn_v3
|
||||
{
|
||||
template <typename ABuffer,
|
||||
typename AOriginIdx,
|
||||
typename BBuffer,
|
||||
typename BOriginIdx,
|
||||
typename CBuffer,
|
||||
typename COriginIdx>
|
||||
__device__ static void Run(const ABuffer& a_buf,
|
||||
AOriginIdx,
|
||||
const BBuffer& b_buf,
|
||||
BOriginIdx,
|
||||
CBuffer& c_buf,
|
||||
COriginIdx)
|
||||
{
|
||||
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
|
||||
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatA>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatB>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<FloatC>>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto E = ADesc{}.GetLength(I0);
|
||||
constexpr auto K = ADesc{}.GetLength(I1);
|
||||
|
||||
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, E, 1>{}([&](auto e) {
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
constexpr index_t a_offset =
|
||||
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));
|
||||
|
||||
if constexpr(H == 2 && W == 2)
|
||||
{
|
||||
constexpr index_t b_offset_0 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
|
||||
constexpr index_t b_offset_1 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
|
||||
constexpr index_t b_offset_2 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
|
||||
constexpr index_t b_offset_3 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
|
||||
|
||||
constexpr index_t c_offset_0 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
|
||||
constexpr index_t c_offset_1 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
|
||||
constexpr index_t c_offset_2 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
|
||||
constexpr index_t c_offset_3 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
|
||||
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset_0>{}],
|
||||
b_buf[Number<b_offset_1>{}],
|
||||
b_buf[Number<b_offset_2>{}],
|
||||
b_buf[Number<b_offset_3>{}],
|
||||
c_buf(Number<c_offset_0>{}),
|
||||
c_buf(Number<c_offset_1>{}),
|
||||
c_buf(Number<c_offset_2>{}),
|
||||
c_buf(Number<c_offset_3>{}));
|
||||
}
|
||||
else if constexpr(H == 4 && W == 1)
|
||||
{
|
||||
constexpr index_t b_offset_0 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
|
||||
constexpr index_t b_offset_1 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
|
||||
constexpr index_t b_offset_2 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
|
||||
constexpr index_t b_offset_3 =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
|
||||
|
||||
constexpr index_t c_offset_0 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
|
||||
constexpr index_t c_offset_1 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
|
||||
constexpr index_t c_offset_2 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
|
||||
constexpr index_t c_offset_3 =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
|
||||
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset_0>{}],
|
||||
b_buf[Number<b_offset_1>{}],
|
||||
b_buf[Number<b_offset_2>{}],
|
||||
b_buf[Number<b_offset_3>{}],
|
||||
c_buf(Number<c_offset_0>{}),
|
||||
c_buf(Number<c_offset_1>{}),
|
||||
c_buf(Number<c_offset_2>{}),
|
||||
c_buf(Number<c_offset_3>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, H, 1>{}([&](auto h) {
|
||||
static_for<0, W, 1>{}([&](auto w) {
|
||||
constexpr index_t b_offset =
|
||||
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
|
||||
|
||||
constexpr index_t c_offset =
|
||||
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
|
||||
|
||||
#if 0
|
||||
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
|
||||
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
|
||||
#else
|
||||
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,59 @@
|
||||
#ifndef CK_THREADWISE_TENSOR_SET_HPP
|
||||
#define CK_THREADWISE_TENSOR_SET_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. Desc is known at compile-time
|
||||
// 2. Buffer is StaticBuffer
|
||||
// 3. OriginIdx is known at compile-time
|
||||
// 4. use #-step
|
||||
template <typename Data,
|
||||
typename Desc,
|
||||
typename SliceLengths,
|
||||
typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceSet_v1
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
template <typename OriginIdx, typename Buffer>
|
||||
__device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
|
||||
{
|
||||
static_assert(Desc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
|
||||
"wrong! OriginIdx need to be known at compile-time");
|
||||
|
||||
// Desc is known at compile-time
|
||||
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
|
||||
|
||||
// OriginIdx is known at compile-time
|
||||
constexpr auto origin_idx = to_multi_index(OriginIdx{});
|
||||
|
||||
static_ford<SliceLengths>{}([&](auto access_idx) {
|
||||
constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
|
||||
|
||||
constexpr bool is_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
|
||||
|
||||
constexpr index_t offset = coord.GetOffset();
|
||||
|
||||
if constexpr(is_valid)
|
||||
{
|
||||
buf(Number<offset>{}) = initial_value;
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,779 @@
|
||||
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
|
||||
// 4. Use thread buffer
|
||||
template <typename SliceLengths,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
typename SrcVectorTensorLengths,
|
||||
typename DstVectorTensorLengths,
|
||||
typename SrcVectorTensorContiguousDimOrder,
|
||||
typename DstVectorTensorContiguousDimOrder,
|
||||
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
|
||||
// RunRead(), will be fused with MoveSrcSliceWindow to
|
||||
// save addr computation
|
||||
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
|
||||
// RunWrite(), will be fused with MoveDstSliceWindow to
|
||||
// save addr computation
|
||||
struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
|
||||
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
|
||||
{
|
||||
// TODO: fix this
|
||||
static_assert(is_same<SrcData, DstData>::value,
|
||||
"wrong! current implementation assume SrcData and DstData are same type");
|
||||
|
||||
static_for<0, nDim, 1>{}([](auto i) {
|
||||
static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 &&
|
||||
SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0,
|
||||
"wrong!");
|
||||
});
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcStepHacks>
|
||||
__device__ void
|
||||
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
// tensor descriptor for src_vector
|
||||
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
|
||||
|
||||
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
|
||||
container_reverse_exclusive_scan(
|
||||
container_reorder_given_new2old(src_vector_tensor_lengths,
|
||||
SrcVectorTensorContiguousDimOrder{}),
|
||||
math::multiplies{},
|
||||
I1),
|
||||
SrcVectorTensorContiguousDimOrder{});
|
||||
|
||||
constexpr auto src_vector_desc =
|
||||
make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
||||
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
||||
|
||||
// access order and lengths
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto src_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, forward_step_idx, src_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto src_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
src_desc, backward_step_idx, src_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_vector_tensor_lengths;
|
||||
}();
|
||||
|
||||
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_vector)::type;
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
// copy data from src_buf to src_vector
|
||||
src_vector.template AsType<src_vector_t>()(I0) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
|
||||
|
||||
// copy data from src_vector to buffer_
|
||||
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
|
||||
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
|
||||
|
||||
constexpr index_t src_vector_offset =
|
||||
src_vector_desc.CalculateOffset(src_vector_idx);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx);
|
||||
|
||||
buffer_(Number<buffer_offset>{}) =
|
||||
src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}];
|
||||
});
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DstBuffer, typename DstStepHacks>
|
||||
__device__ void
|
||||
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
|
||||
{
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
|
||||
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
// tensor descriptor for dst_vector
|
||||
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
|
||||
|
||||
constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new(
|
||||
container_reverse_exclusive_scan(
|
||||
container_reorder_given_new2old(dst_vector_tensor_lengths,
|
||||
DstVectorTensorContiguousDimOrder{}),
|
||||
math::multiplies{},
|
||||
I1),
|
||||
DstVectorTensorContiguousDimOrder{});
|
||||
|
||||
constexpr auto dst_vector_desc =
|
||||
make_naive_tensor_descriptor(sequence_to_tuple_of_number(dst_vector_tensor_lengths),
|
||||
sequence_to_tuple_of_number(dst_vector_tensor_strides));
|
||||
|
||||
// dst access order and lengths
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// make forward steps
|
||||
const auto dst_forward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward steps
|
||||
const auto dst_backward_steps = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(
|
||||
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
|
||||
: ordered_dst_access_lengths[i] - 1 -
|
||||
ordered_dst_access_idx[i];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_vector_tensor_lengths;
|
||||
}();
|
||||
|
||||
vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector;
|
||||
|
||||
// copy data from buffer_ to dst_vector (also cast from SrcData to DstData)
|
||||
static_ford<DstVectorTensorLengths>{}([&](auto dst_vector_idx_) {
|
||||
constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx);
|
||||
|
||||
constexpr index_t dst_vector_offset =
|
||||
dst_vector_desc.CalculateOffset(dst_vector_idx);
|
||||
|
||||
dst_vector.template AsType<DstData>()(Number<dst_vector_offset>{}) =
|
||||
type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
|
||||
});
|
||||
|
||||
using dst_vector_t = typename decltype(dst_vector)::type;
|
||||
|
||||
// copy data from dst_vector to dst_buf
|
||||
const bool is_dst_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
|
||||
|
||||
dst_buf.template Set<dst_vector_t>(
|
||||
dst_coord_.GetOffset(),
|
||||
is_dst_valid,
|
||||
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &=
|
||||
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(
|
||||
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move dst coordinate back to slice origin (or not)
|
||||
if constexpr(DstResetCoordinateAfterRun)
|
||||
{
|
||||
const auto dst_reset_step =
|
||||
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename SrcBuffer>
|
||||
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
|
||||
{
|
||||
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
|
||||
|
||||
constexpr auto src_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunRead(src_desc, src_buf, src_step_hacks);
|
||||
}
|
||||
|
||||
template <typename DstBuffer>
|
||||
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
|
||||
{
|
||||
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
|
||||
|
||||
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
|
||||
|
||||
constexpr auto dst_step_hacks =
|
||||
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
|
||||
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
|
||||
|
||||
RunWrite(dst_desc, dst_buf, dst_step_hacks);
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
// RunRead()
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_vector_tensor_lengths;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetDstCoordinateResetStep()
|
||||
{
|
||||
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
|
||||
|
||||
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_dst_access_lengths =
|
||||
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
// RunWrite()
|
||||
constexpr auto dst_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_vector_tensor_lengths;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx,
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
|
||||
const Index& dst_slice_origin_step_idx)
|
||||
{
|
||||
// if dst coord was not reset by RunWrite(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
|
||||
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto buffer_desc_ =
|
||||
make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
|
||||
|
||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
};
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is known at compile-time
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_ref_idx is known at run-time
|
||||
// 4. SrcRefToOriginDisplacement is known at compile-time
|
||||
// 5. use #-step
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
// 3. DstOriginIdx is known at compile-time
|
||||
// 4. use direct address calculation
|
||||
// 3. vector access on src
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
typename SrcVectorTensorLengths,
|
||||
typename SrcVectorTensorContiguousDimOrder,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
|
||||
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_for<0, nDim, 1>{}([](auto i) {
|
||||
static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!");
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcRefToOriginDisplacement,
|
||||
typename DstOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcRefToOriginDisplacement&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstOriginIdx&,
|
||||
DstBuffer& dst_buf) const
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<DstData>>>::value,
|
||||
"wrong! SrcBuffer or DstBuffer data type is wrong");
|
||||
|
||||
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
|
||||
|
||||
static_assert(
|
||||
is_known_at_compile_time<
|
||||
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
|
||||
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
|
||||
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
|
||||
"at compile-time");
|
||||
|
||||
// SrcDesc and DstDesc are known at compile-time
|
||||
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
|
||||
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
|
||||
|
||||
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
|
||||
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
|
||||
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
|
||||
|
||||
// tensor descriptor for src_vector
|
||||
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
|
||||
|
||||
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
|
||||
container_reverse_exclusive_scan(
|
||||
container_reorder_given_new2old(src_vector_tensor_lengths,
|
||||
SrcVectorTensorContiguousDimOrder{}),
|
||||
math::multiplies{},
|
||||
I1),
|
||||
SrcVectorTensorContiguousDimOrder{});
|
||||
|
||||
constexpr auto src_vector_desc =
|
||||
make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
||||
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
||||
|
||||
// access order and lengths
|
||||
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
||||
|
||||
constexpr auto dim_access_order = DimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
|
||||
// position in slice window
|
||||
constexpr auto data_to_origin_disp_idx =
|
||||
ordered_access_idx.ReorderGivenOld2New(dim_access_order) *
|
||||
src_vector_tensor_lengths;
|
||||
|
||||
// src coordinate at starting point of src_vector
|
||||
constexpr auto src_ref_to_data_disp_idx =
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
|
||||
|
||||
constexpr auto src_ref_to_data_disp_coord_step =
|
||||
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
|
||||
|
||||
auto src_data_coord = src_ref_coord_;
|
||||
|
||||
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
|
||||
|
||||
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_vector)::type;
|
||||
|
||||
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
|
||||
src_desc, src_data_coord);
|
||||
|
||||
// copy data from src_buf into src_vector
|
||||
src_vector.template AsType<src_vector_t>()(I0) =
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
|
||||
|
||||
// copy data from src_vector into dst_buf (also cast from SrcData to DstData)
|
||||
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
|
||||
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
|
||||
|
||||
constexpr index_t src_vector_offset =
|
||||
src_vector_desc.CalculateOffset(src_vector_idx);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
|
||||
|
||||
dst_buf(Number<dst_offset>{}) = type_convert<DstData>{}(
|
||||
src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcSliceMoveStepIdx>
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc&,
|
||||
const SrcSliceMoveStepIdx& src_slice_move_step_idx)
|
||||
{
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
|
||||
const auto src_slice_move_step_iter =
|
||||
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
|
||||
|
||||
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoord src_ref_coord_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
801
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
Normal file
801
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
Normal file
@@ -0,0 +1,801 @@
|
||||
#ifndef CK_XDLOPS_GEMM_HPP
|
||||
#define CK_XDLOPS_GEMM_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "math.hpp"
|
||||
#include "amd_xdlops.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct mfma_instr
|
||||
{
|
||||
/// fp32
|
||||
mfma_f32_32x32x1xf32 = 0,
|
||||
mfma_f32_16x16x1xf32,
|
||||
mfma_f32_4x4x1xf32,
|
||||
mfma_f32_32x32x2xf32, // k reduction
|
||||
mfma_f32_16x16x4xf32, // k reduction
|
||||
/// fp16
|
||||
mfma_f32_32x32x4f16,
|
||||
mfma_f32_16x16x4f16,
|
||||
mfma_f32_4x4x4f16,
|
||||
mfma_f32_32x32x8f16, // k reduction
|
||||
mfma_f32_16x16x16f16, // k reduction
|
||||
/// bfp16
|
||||
mfma_f32_32x32x2bf16,
|
||||
mfma_f32_16x16x2bf16,
|
||||
mfma_f32_4x4x2bf16,
|
||||
mfma_f32_32x32x4bf16, // k reduction
|
||||
mfma_f32_16x16x8bf16, // k reduction
|
||||
};
|
||||
|
||||
template <mfma_instr instr>
|
||||
struct mfma_info;
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 4;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 2;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 32;
|
||||
static constexpr index_t n = 32;
|
||||
static constexpr index_t k = 1;
|
||||
static constexpr index_t cycles = 64;
|
||||
static constexpr index_t k_base = 1;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 4;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 32;
|
||||
static constexpr index_t n = 32;
|
||||
static constexpr index_t k = 2;
|
||||
static constexpr index_t cycles = 64;
|
||||
static constexpr index_t k_base = 1;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 16;
|
||||
static constexpr index_t n = 16;
|
||||
static constexpr index_t k = 4;
|
||||
static constexpr index_t cycles = 32;
|
||||
static constexpr index_t k_base = 1;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 4;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 16;
|
||||
static constexpr index_t n = 16;
|
||||
static constexpr index_t k = 1;
|
||||
static constexpr index_t cycles = 32;
|
||||
static constexpr index_t k_base = 1;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
// treat 4x4x1 as a single-blk 4x64 mfma
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 64;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = 4;
|
||||
static constexpr index_t m = 4;
|
||||
static constexpr index_t n = 64;
|
||||
static constexpr index_t k = 1;
|
||||
static constexpr index_t cycles = 8;
|
||||
static constexpr index_t k_base = 1;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 4;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 2;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 32;
|
||||
static constexpr index_t n = 32;
|
||||
static constexpr index_t k = 4;
|
||||
static constexpr index_t cycles = 64;
|
||||
static constexpr index_t k_base = 4;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 4;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 32;
|
||||
static constexpr index_t n = 32;
|
||||
static constexpr index_t k = 8;
|
||||
static constexpr index_t cycles = 64;
|
||||
static constexpr index_t k_base = 4;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 16;
|
||||
static constexpr index_t n = 16;
|
||||
static constexpr index_t k = 16;
|
||||
static constexpr index_t cycles = 32;
|
||||
static constexpr index_t k_base = 4;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 4;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 16;
|
||||
static constexpr index_t n = 16;
|
||||
static constexpr index_t k = 4;
|
||||
static constexpr index_t cycles = 32;
|
||||
static constexpr index_t k_base = 4;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 64;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = 4;
|
||||
static constexpr index_t m = 4;
|
||||
static constexpr index_t n = 64;
|
||||
static constexpr index_t k = 4;
|
||||
static constexpr index_t cycles = 8;
|
||||
static constexpr index_t k_base = 4;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t COffset,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 4;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 2;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 32;
|
||||
static constexpr index_t n = 32;
|
||||
static constexpr index_t k = 2;
|
||||
static constexpr index_t cycles = 64;
|
||||
static constexpr index_t k_base = 2;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
|
||||
p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 4;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 32;
|
||||
static constexpr index_t n = 32;
|
||||
static constexpr index_t k = 4;
|
||||
static constexpr index_t cycles = 64;
|
||||
static constexpr index_t k_base = 2;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 16;
|
||||
static constexpr index_t n = 16;
|
||||
static constexpr index_t k = 8;
|
||||
static constexpr index_t cycles = 32;
|
||||
static constexpr index_t k_base = 2;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
|
||||
static constexpr index_t num_output_blks = 4;
|
||||
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
|
||||
static constexpr index_t m = 16;
|
||||
static constexpr index_t n = 16;
|
||||
static constexpr index_t k = 2;
|
||||
static constexpr index_t cycles = 32;
|
||||
static constexpr index_t k_base = 2;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_blk = 1;
|
||||
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
|
||||
static constexpr index_t num_threads_blk = 64;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t num_regs_xdlops = 4;
|
||||
static constexpr index_t m = 4;
|
||||
static constexpr index_t n = 64;
|
||||
static constexpr index_t k = 2;
|
||||
static constexpr index_t cycles = 8;
|
||||
static constexpr index_t k_base = 2;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
|
||||
struct xdlops_info
|
||||
{
|
||||
static constexpr auto mfma_type = mfma_info<instr>{};
|
||||
|
||||
static constexpr index_t MPerXdlops = MPerXdlops_;
|
||||
static constexpr index_t NPerXdlops = NPerXdlops_;
|
||||
|
||||
static constexpr bool IsABroadcast()
|
||||
{
|
||||
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
|
||||
return true;
|
||||
}
|
||||
|
||||
static constexpr bool IsKReduction()
|
||||
{
|
||||
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
|
||||
}
|
||||
|
||||
static constexpr index_t GetKPerXdlops()
|
||||
{
|
||||
return IsKReduction() ? mfma_type.num_input_blks : 1;
|
||||
}
|
||||
|
||||
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
|
||||
};
|
||||
|
||||
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
|
||||
struct XdlopsGemm
|
||||
{
|
||||
template <class base_type_ = base_type,
|
||||
index_t MPerWave_ = MPerWave,
|
||||
index_t NPerWave_ = NPerWave>
|
||||
static constexpr auto GetXdlopsInfo();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 64, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 32, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 16, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 8, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 4, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 32, 32>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<float, 16, 16>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
|
||||
{
|
||||
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
|
||||
}
|
||||
#endif
|
||||
|
||||
using CIndex = MultiIndex<2>;
|
||||
|
||||
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
|
||||
|
||||
__device__ static constexpr index_t GetNumXdlops()
|
||||
{
|
||||
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr XdlopsGemm()
|
||||
{
|
||||
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
|
||||
NPerXdlops == 64,
|
||||
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
|
||||
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
|
||||
"m != num_input_blks * num_regs_blk");
|
||||
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
|
||||
mfma_type.num_output_blks == 1,
|
||||
"incorrect num_output_blks");
|
||||
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
|
||||
"num_regs_blk incorrect");
|
||||
|
||||
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegSizePerXdlops()
|
||||
{
|
||||
return MPerXdlops * NPerXdlops / mfma_type.wave_size;
|
||||
}
|
||||
|
||||
template <class ADesc,
|
||||
class BDesc,
|
||||
class CDesc,
|
||||
index_t m0,
|
||||
index_t n0,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, ushort>::value,
|
||||
"base base_type must be float, half, ushort!");
|
||||
|
||||
static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");
|
||||
|
||||
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
|
||||
|
||||
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
|
||||
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
|
||||
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
|
||||
|
||||
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
|
||||
p_a_wave[Number<a_offset / mfma_type.k_base>{}],
|
||||
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
|
||||
p_c_thread);
|
||||
});
|
||||
}
|
||||
|
||||
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
|
||||
{
|
||||
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
|
||||
const index_t blk_id = laneId / mfma_type.num_threads_blk;
|
||||
const index_t blk_td = laneId % mfma_type.num_threads_blk;
|
||||
|
||||
index_t n_offset = blk_i * mfma_type.n + blk_td;
|
||||
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
|
||||
|
||||
return CIndex{m_offset, n_offset};
|
||||
}
|
||||
|
||||
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
|
||||
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
|
||||
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
|
||||
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
|
||||
|
||||
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
|
||||
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
|
||||
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
|
||||
|
||||
static constexpr auto GetBlkId(const index_t lane_id)
|
||||
{
|
||||
return lane_id / mfma_type.num_threads_blk;
|
||||
}
|
||||
|
||||
static constexpr auto GetBlkTd(const index_t lane_id)
|
||||
{
|
||||
return lane_id % mfma_type.num_threads_blk;
|
||||
}
|
||||
|
||||
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
|
||||
|
||||
struct CLayout
|
||||
{
|
||||
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
|
||||
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
|
||||
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
|
||||
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
|
||||
|
||||
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
|
||||
|
||||
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
|
||||
|
||||
__device__ static constexpr index_t GetNumXdlops()
|
||||
{
|
||||
return MPerXdlops * NPerXdlops /
|
||||
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
|
||||
}
|
||||
};
|
||||
|
||||
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
44
composable_kernel/include/utility/amd_address_space.hpp
Normal file
44
composable_kernel/include/utility/amd_address_space.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
#ifndef CK_AMD_ADDRESS_SPACE_HPP
|
||||
#define CK_AMD_ADDRESS_SPACE_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
|
||||
// Address Space for AMDGCN
|
||||
// https://llvm.org/docs/AMDGPUUsage.html#address-space
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpaceEnum_t
|
||||
{
|
||||
Generic,
|
||||
Global,
|
||||
Lds,
|
||||
Sgpr,
|
||||
Vgpr,
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
|
||||
{
|
||||
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
|
||||
// only c-style pointer cast seems be able to be compiled
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
|
||||
{
|
||||
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
|
||||
// only c-style pointer cast seems be able to be compiled
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
return (T CONSTANT*)p; // NOLINT(old-style-cast)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
681
composable_kernel/include/utility/amd_buffer_addressing.hpp
Normal file
681
composable_kernel/include/utility/amd_buffer_addressing.hpp
Normal file
@@ -0,0 +1,681 @@
|
||||
#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
|
||||
#define CK_AMD_BUFFER_ADDRESSING_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T>
|
||||
union BufferResource
|
||||
{
|
||||
// 128 bit SGPRs to supply buffer resource in buffer instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
int32x4_t content;
|
||||
StaticallyIndexedArray<T*, 2> address;
|
||||
StaticallyIndexedArray<int32_t, 4> range;
|
||||
StaticallyIndexedArray<int32_t, 4> config;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
|
||||
{
|
||||
BufferResource<T> wave_buffer_resource;
|
||||
|
||||
// wavewise base address (64 bit)
|
||||
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
|
||||
// wavewise range (32 bit)
|
||||
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
|
||||
// wavewise setting (32 bit)
|
||||
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
|
||||
|
||||
return wave_buffer_resource.content;
|
||||
}
|
||||
|
||||
// load
|
||||
__device__ int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
|
||||
|
||||
__device__ int8x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
|
||||
|
||||
__device__ int8x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
__device__ int16_t
|
||||
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
__device__ int32_t
|
||||
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
|
||||
__device__ int32x2_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
|
||||
|
||||
__device__ int32x4_t
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
// half
|
||||
__device__ half_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
|
||||
|
||||
__device__ half2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
|
||||
|
||||
__device__ half4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
|
||||
|
||||
// float
|
||||
__device__ float
|
||||
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
|
||||
|
||||
__device__ float2_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
|
||||
|
||||
__device__ float4_t
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
|
||||
|
||||
// store
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
// half
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
|
||||
// float
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<float, 8> tmp;
|
||||
|
||||
tmp.AsType<float4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<float4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
0);
|
||||
|
||||
return tmp.AsType<float8_t>()(Number<0>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
#if 0
|
||||
vector_type<half_t, 8> tmp;
|
||||
|
||||
tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<half4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
|
||||
return tmp.AsType<half8_t>()(Number<0>{});
|
||||
#else
|
||||
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<half8_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<int32_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<int32x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
0);
|
||||
return tmp.AsType<int32x8_t>()(Number<0>{});
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int8_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i8(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
return llvm_amdgcn_raw_buffer_load_i8x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
#else
|
||||
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x2_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
return llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
#else
|
||||
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x4_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
vector_type<int8_t, 8> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
return tmp.AsType<int8x8_t>()(Number<0>{});
|
||||
#else
|
||||
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x8_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
vector_type<int8_t, 16> tmp;
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<1>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<2>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 8 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
tmp.AsType<int8x4_t>()(Number<3>{}) =
|
||||
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 12 * sizeof(int8_t),
|
||||
0);
|
||||
|
||||
return tmp.AsType<int8x16_t>()(Number<0>{});
|
||||
#else
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<int8x16_t>(tmp);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int8_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#else
|
||||
llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<half_t, 8> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
bool src_thread_element_valid,
|
||||
index_t src_element_space_size)
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
return amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#else
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must be in global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ typename vector_type_maker<T, N>::type::type
|
||||
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
index_t src_thread_element_offset,
|
||||
bool src_thread_element_valid,
|
||||
index_t src_element_space_size,
|
||||
T customized_value)
|
||||
{
|
||||
const int32x4_t src_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_src_wave, src_element_space_size);
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(customized_value);
|
||||
}
|
||||
|
||||
// buffer_store requires:
|
||||
// 1) p_dst_wave must be global memory
|
||||
// 2) p_dst_wave to be a wavewise pointer.
|
||||
// It is user's responsibility to make sure that is true.
|
||||
template <typename T, index_t N>
|
||||
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
using vector_t = typename vector_type_maker<T, N>::type::type;
|
||||
using scalar_t = typename scalar_type<vector_t>::type;
|
||||
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
|
||||
|
||||
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_store_impl<scalar_t, vector_size>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
356
composable_kernel/include/utility/amd_inline_asm.hpp
Normal file
356
composable_kernel/include/utility/amd_inline_asm.hpp
Normal file
@@ -0,0 +1,356 @@
|
||||
#ifndef CK_AMD_INLINE_ASM_HPP
|
||||
#define CK_AMD_INLINE_ASM_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
|
||||
// TODO: deprecate all amd_assembly_outer_product_xxx
|
||||
|
||||
namespace ck {
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_fmac_f32 %0, %2, %3 \n \
|
||||
v_fmac_f32 %1, %2, %4 \n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
// c2 += inner_product(a, b2)
|
||||
// c3 += inner_product(a, b3)
|
||||
__device__ void amd_assembly_outer_product_1x4(
|
||||
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_fmac_f32 %0, %4, %5 \n \
|
||||
v_fmac_f32 %1, %4, %6 \n \
|
||||
v_fmac_f32 %2, %4, %7 \n \
|
||||
v_fmac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void
|
||||
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %2, %3, %0\n \
|
||||
v_dot2_f32_f16 %1, %2, %4, %1\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void
|
||||
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
|
||||
{
|
||||
// TODO remove pointer casting
|
||||
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
|
||||
|
||||
// do dot2 two times
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %2, %4, %0\n \
|
||||
v_dot2_f32_f16 %1, %2, %6, %1\n \
|
||||
v_dot2_f32_f16 %0, %3, %5, %0\n \
|
||||
v_dot2_f32_f16 %1, %3, %7, %1\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(p_a_half2[0]),
|
||||
"v"(p_a_half2[1]),
|
||||
"v"(p_b0_half2[0]),
|
||||
"v"(p_b0_half2[1]),
|
||||
"v"(p_b1_half2[0]),
|
||||
"v"(p_b1_half2[1]),
|
||||
"0"(c0),
|
||||
"1"(c1));
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
// c2 += inner_product(a, b2)
|
||||
// c3 += inner_product(a, b3)
|
||||
__device__ void amd_assembly_outer_product_1x4(half2_t a,
|
||||
half2_t b0,
|
||||
half2_t b1,
|
||||
half2_t b2,
|
||||
half2_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %4, %5, %0\n \
|
||||
v_dot2_f32_f16 %1, %4, %6, %1\n \
|
||||
v_dot2_f32_f16 %2, %4, %7, %2\n \
|
||||
v_dot2_f32_f16 %3, %4, %8, %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
// c2 += inner_product(a, b2)
|
||||
// c3 += inner_product(a, b3)
|
||||
__device__ void amd_assembly_outer_product_1x4(half4_t a,
|
||||
half4_t b0,
|
||||
half4_t b1,
|
||||
half4_t b2,
|
||||
half4_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
// TODO remove pointer casting
|
||||
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
|
||||
const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
|
||||
const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
|
||||
|
||||
// do dot2 two times
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %4, %6, %0\n \
|
||||
v_dot2_f32_f16 %1, %4, %8, %1\n \
|
||||
v_dot2_f32_f16 %2, %4, %10, %2\n \
|
||||
v_dot2_f32_f16 %3, %4, %12, %3\n \
|
||||
v_dot2_f32_f16 %0, %5, %7, %0\n \
|
||||
v_dot2_f32_f16 %1, %5, %9, %1\n \
|
||||
v_dot2_f32_f16 %2, %5, %11, %2\n \
|
||||
v_dot2_f32_f16 %3, %5, %13, %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(p_a_half2[0]),
|
||||
"v"(p_a_half2[1]),
|
||||
"v"(p_b0_half2[0]),
|
||||
"v"(p_b0_half2[1]),
|
||||
"v"(p_b1_half2[0]),
|
||||
"v"(p_b1_half2[1]),
|
||||
"v"(p_b2_half2[0]),
|
||||
"v"(p_b2_half2[1]),
|
||||
"v"(p_b3_half2[0]),
|
||||
"v"(p_b3_half2[1]),
|
||||
"0"(c0),
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3));
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(half8_t a,
|
||||
half8_t b0,
|
||||
half8_t b1,
|
||||
half8_t b2,
|
||||
half8_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
|
||||
// TODO remove pointer casting
|
||||
const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
|
||||
const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
|
||||
const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
|
||||
const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
|
||||
const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(half16_t a,
|
||||
half16_t b0,
|
||||
half16_t b1,
|
||||
half16_t b2,
|
||||
half16_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
// TODO remove pointer casting
|
||||
const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
|
||||
const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
|
||||
const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
|
||||
const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
|
||||
const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void
|
||||
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
|
||||
{
|
||||
#if 1
|
||||
asm volatile("\n \
|
||||
v_dot4_i32_i8 %0, %2, %3, %0\n \
|
||||
v_dot4_i32_i8 %1, %2, %4, %1\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(as_type<int32_t>(a)),
|
||||
"v"(as_type<int32_t>(b0)),
|
||||
"v"(as_type<int32_t>(b1)),
|
||||
"0"(c0),
|
||||
"1"(c1));
|
||||
#else
|
||||
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
|
||||
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
// c2 += inner_product(a, b2)
|
||||
// c3 += inner_product(a, b3)
|
||||
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
|
||||
int8x4_t b0,
|
||||
int8x4_t b1,
|
||||
int8x4_t b2,
|
||||
int8x4_t b3,
|
||||
int32_t& c0,
|
||||
int32_t& c1,
|
||||
int32_t& c2,
|
||||
int32_t& c3)
|
||||
{
|
||||
#if 1
|
||||
asm volatile("\n \
|
||||
v_dot4_i32_i8 %0, %4, %5, %0\n \
|
||||
v_dot4_i32_i8 %1, %4, %6, %1\n \
|
||||
v_dot4_i32_i8 %2, %4, %7, %2\n \
|
||||
v_dot4_i32_i8 %3, %4, %8, %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(as_type<int32_t>(a)),
|
||||
"v"(as_type<int32_t>(b0)),
|
||||
"v"(as_type<int32_t>(b1)),
|
||||
"v"(as_type<int32_t>(b2)),
|
||||
"v"(as_type<int32_t>(b3)),
|
||||
"0"(c0),
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3));
|
||||
#else
|
||||
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
|
||||
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
|
||||
c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
|
||||
c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
|
||||
int8x8_t b0,
|
||||
int8x8_t b1,
|
||||
int8x8_t b2,
|
||||
int8x8_t b3,
|
||||
int32_t& c0,
|
||||
int32_t& c1,
|
||||
int32_t& c2,
|
||||
int32_t& c3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
|
||||
int8x16_t b0,
|
||||
int8x16_t b1,
|
||||
int8x16_t b2,
|
||||
int8x16_t b3,
|
||||
int32_t& c0,
|
||||
int32_t& c1,
|
||||
int32_t& c2,
|
||||
int32_t& c3)
|
||||
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
11
composable_kernel/include/utility/amd_llvm_intrinsic.hpp
Normal file
11
composable_kernel/include/utility/amd_llvm_intrinsic.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
|
||||
#define CK_AMD_LLVM_INTRINSIC_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
499
composable_kernel/include/utility/amd_xdlops.hpp
Normal file
499
composable_kernel/include/utility/amd_xdlops.hpp
Normal file
@@ -0,0 +1,499 @@
|
||||
#ifndef CK_AMD_XDLOPS_HPP
|
||||
#define CK_AMD_XDLOPS_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A, B, C, cbsz, abid, blgp
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
|
||||
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
|
||||
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
|
||||
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
|
||||
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x2f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
2,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
1,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
1,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
||||
1,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x8f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x16f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
||||
2,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16;
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t COffset>
|
||||
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
0,
|
||||
0);
|
||||
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
||||
4,
|
||||
1,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16;
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_4_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
reg_c.s.z =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
|
||||
reg_c.s.w =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_4_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
reg_c.s.z =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
|
||||
reg_c.s.w =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_2_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec4_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c);
|
||||
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x2bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x2bf16<4, 64>
|
||||
{
|
||||
__device__ static c_vec4_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x2bf16<8, 64>
|
||||
{
|
||||
__device__ static c_vec4_2_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
63
composable_kernel/include/utility/array.hpp
Normal file
63
composable_kernel/include/utility/array.hpp
Normal file
@@ -0,0 +1,63 @@
|
||||
#ifndef CK_ARRAY_HPP
|
||||
#define CK_ARRAY_HPP
|
||||
|
||||
#include "functional2.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename TData, index_t NSize>
|
||||
struct Array
|
||||
{
|
||||
using type = Array;
|
||||
using data_type = TData;
|
||||
|
||||
TData mData[NSize];
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return NSize; }
|
||||
|
||||
__host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
|
||||
|
||||
__host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
|
||||
|
||||
__host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); }
|
||||
|
||||
__host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto operator=(const T& a)
|
||||
{
|
||||
static_assert(T::Size() == Size(), "wrong! size not the same");
|
||||
|
||||
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
|
||||
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
// empty Array
|
||||
template <typename TData>
|
||||
struct Array<TData, 0>
|
||||
{
|
||||
using type = Array;
|
||||
using data_type = TData;
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 0; }
|
||||
};
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
|
||||
{
|
||||
using data_type = remove_cv_t<remove_reference_t<X>>;
|
||||
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
|
||||
}
|
||||
|
||||
// make empty array
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto make_array()
|
||||
{
|
||||
return Array<X, 0>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
77
composable_kernel/include/utility/array_multi_index.hpp
Normal file
77
composable_kernel/include/utility/array_multi_index.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
#ifndef CK_ARRAY_MULTI_INDEX_HPP
|
||||
#define CK_ARRAY_MULTI_INDEX_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using MultiIndex = Array<index_t, N>;
|
||||
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
|
||||
{
|
||||
return make_array<index_t>(index_t{xs}...);
|
||||
}
|
||||
|
||||
template <index_t NSize>
|
||||
__host__ __device__ constexpr auto make_zero_multi_index()
|
||||
{
|
||||
return unpack([](auto... xs) { return make_multi_index(xs...); },
|
||||
typename uniform_sequence_gen<NSize, 0>::type{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto to_multi_index(const T& x)
|
||||
{
|
||||
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
|
||||
}
|
||||
|
||||
template <index_t NSize, typename X>
|
||||
__host__ __device__ constexpr auto operator+=(MultiIndex<NSize>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == NSize, "wrong! size not the same");
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename X>
|
||||
__host__ __device__ constexpr auto operator-=(MultiIndex<NSize>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == NSize, "wrong! size not the same");
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
__host__ __device__ constexpr auto operator+(const MultiIndex<NSize>& a, const T& b)
|
||||
{
|
||||
using type = MultiIndex<NSize>;
|
||||
static_assert(T::Size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
__host__ __device__ constexpr auto operator-(const MultiIndex<NSize>& a, const T& b)
|
||||
{
|
||||
using type = MultiIndex<NSize>;
|
||||
static_assert(T::Size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <index_t NSize, typename T>
|
||||
__host__ __device__ constexpr auto operator*(const MultiIndex<NSize>& a, const T& b)
|
||||
{
|
||||
using type = MultiIndex<NSize>;
|
||||
static_assert(T::Size() == NSize, "wrong! size not the same");
|
||||
type r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
22
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
22
composable_kernel/include/utility/c_style_pointer_cast.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#ifndef CK_C_STYLE_POINTER_CAST_HPP
|
||||
#define CK_C_STYLE_POINTER_CAST_HPP
|
||||
|
||||
#include "type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename PY,
|
||||
typename PX,
|
||||
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
|
||||
__host__ __device__ PY c_style_pointer_cast(PX p_x)
|
||||
{
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
#pragma clang diagnostic ignored "-Wcast-align"
|
||||
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
46
composable_kernel/include/utility/common_header.hpp
Normal file
46
composable_kernel/include/utility/common_header.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef CK_COMMON_HEADER_HPP
|
||||
#define CK_COMMON_HEADER_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "array.hpp"
|
||||
#include "container_helper.hpp"
|
||||
#include "statically_indexed_array.hpp"
|
||||
#include "container_element_picker.hpp"
|
||||
#include "multi_index.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "data_type_enum.hpp"
|
||||
#include "data_type_enum_helper.hpp"
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "functional3.hpp"
|
||||
#include "functional4.hpp"
|
||||
#include "enable_if.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "math.hpp"
|
||||
#include "number.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "sequence_helper.hpp"
|
||||
#include "synchronization.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "tuple_helper.hpp"
|
||||
#include "type.hpp"
|
||||
#include "magic_division.hpp"
|
||||
#include "utility.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "amd_address_space.hpp"
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "static_buffer.hpp"
|
||||
#include "dynamic_buffer.hpp"
|
||||
|
||||
#include "inner_product.hpp"
|
||||
|
||||
// TODO: remove this
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
#include "amd_inline_asm.hpp"
|
||||
#endif
|
||||
|
||||
#if CK_USE_AMD_XDLOPS
|
||||
#include "amd_xdlops.hpp"
|
||||
#endif
|
||||
|
||||
#endif
|
||||
134
composable_kernel/include/utility/config.hpp
Normal file
134
composable_kernel/include/utility/config.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
#ifndef CK_CONFIG_AMD_HPP
|
||||
#define CK_CONFIG_AMD_HPP
|
||||
|
||||
#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#endif
|
||||
#include "bfloat16_dev.hpp"
|
||||
|
||||
// "Constant" address space for kernel parameter
|
||||
#define CONSTANT __attribute__((address_space(4)))
|
||||
|
||||
// GPU target
|
||||
// should enable one and only one GPU target
|
||||
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
|
||||
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
|
||||
#error Need to define (only) one GPU target
|
||||
#endif
|
||||
|
||||
// launch bounds
|
||||
#define CK_USE_LAUNCH_BOUNDS 1
|
||||
|
||||
#ifdef CK_USE_LAUNCH_BOUNDS
|
||||
#define CK_MAX_THREAD_PER_BLOCK 256
|
||||
#define CK_MIN_BLOCK_PER_CU 2
|
||||
#endif
|
||||
|
||||
// buffer resourse
|
||||
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
|
||||
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
|
||||
#elif defined(CK_AMD_GPU_GFX1030)
|
||||
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
|
||||
#endif
|
||||
|
||||
// FMA instruction
|
||||
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
|
||||
#define CK_USE_AMD_V_MAC_F32
|
||||
#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \
|
||||
defined(CK_AMD_GPU_GFX1030)
|
||||
#define CK_USE_AMD_V_FMAC_F32
|
||||
#define CK_USE_AMD_V_DOT2_F32_F16
|
||||
#define CK_USE_AMD_V_DOT4_I32_I8
|
||||
#endif
|
||||
|
||||
// multi index
|
||||
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
|
||||
|
||||
// AMD inline asm
|
||||
#ifndef CK_USE_AMD_INLINE_ASM
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
// AMD inner product (DLOP)
|
||||
#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
// AMD buffer addressing
|
||||
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING 1
|
||||
#endif
|
||||
|
||||
// only gfx908 support native floating point atomic add
|
||||
#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD
|
||||
#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0
|
||||
#endif
|
||||
|
||||
// AMD XDLOPS
|
||||
#ifndef CK_USE_AMD_XDLOPS
|
||||
#define CK_USE_AMD_XDLOPS 0
|
||||
#endif
|
||||
|
||||
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
|
||||
#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
|
||||
#endif
|
||||
|
||||
// experimental implementation
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
|
||||
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1
|
||||
#endif
|
||||
|
||||
// pass tensor descriptor by value or void*
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
|
||||
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
|
||||
|
||||
// merge transformation use magic number division
|
||||
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
|
||||
|
||||
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
|
||||
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
|
||||
// thread-invariant, otherwise it's a bug
|
||||
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
|
||||
#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
|
||||
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
|
||||
#endif
|
||||
|
||||
// workaround for compiler crash when compiling recursive lambda
|
||||
#ifndef CK_WORKAROUND_SWDEV_275126
|
||||
#define CK_WORKAROUND_SWDEV_275126 1
|
||||
#endif
|
||||
|
||||
// workaround for compiler crash when using buffer load/store for i8
|
||||
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
|
||||
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
|
||||
#endif
|
||||
|
||||
// workaround for compiler crash when using buffer load/store for i8
|
||||
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum InMemoryDataOperationEnum_t
|
||||
{
|
||||
Set,
|
||||
AtomicAdd
|
||||
};
|
||||
|
||||
// index type
|
||||
using index_t = int32_t;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
155
composable_kernel/include/utility/container_element_picker.hpp
Normal file
155
composable_kernel/include/utility/container_element_picker.hpp
Normal file
@@ -0,0 +1,155 @@
|
||||
#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP
|
||||
#define CK_CONTAINER_ELEMENT_PICKER_HPP
|
||||
|
||||
#include "functional2.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Arr: Array or StaticallyIndexedArray
|
||||
// Picks: Sequence<...>
|
||||
template <typename Arr, typename Picks>
|
||||
struct ContainerElementPicker
|
||||
{
|
||||
using type = ContainerElementPicker;
|
||||
#if 0
|
||||
using data_type = typename Arr::data_type;
|
||||
#endif
|
||||
|
||||
__host__ __device__ constexpr ContainerElementPicker() = delete;
|
||||
|
||||
__host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array}
|
||||
{
|
||||
constexpr index_t imax =
|
||||
reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});
|
||||
|
||||
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& At(Number<I> i) const
|
||||
{
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
constexpr auto IP = Picks{}[i];
|
||||
return mArray[IP];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& At(Number<I> i)
|
||||
{
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
constexpr auto IP = Picks{}[i];
|
||||
return mArray(IP);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto operator=(const T& a)
|
||||
{
|
||||
static_assert(T::Size() == Size(), "wrong! size not the same");
|
||||
|
||||
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
private:
|
||||
Arr& mArray;
|
||||
};
|
||||
|
||||
// Arr: Array or StaticallyIndexedArray
|
||||
// Picks: Sequence<...>
|
||||
template <typename Arr, typename Picks>
|
||||
struct ConstantContainerElementPicker
|
||||
{
|
||||
using type = ConstantContainerElementPicker;
|
||||
#if 0
|
||||
using data_type = typename Arr::data_type;
|
||||
#endif
|
||||
|
||||
__host__ __device__ constexpr ConstantContainerElementPicker() = delete;
|
||||
|
||||
__host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array}
|
||||
{
|
||||
constexpr index_t imax =
|
||||
reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});
|
||||
|
||||
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& At(Number<I> i) const
|
||||
{
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
constexpr auto IP = Picks{}[i];
|
||||
return mArray[IP];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
private:
|
||||
const Arr& mArray;
|
||||
};
|
||||
|
||||
template <typename Arr, typename Picks, typename X>
|
||||
__host__ __device__ constexpr auto operator+=(ContainerElementPicker<Arr, Picks>& y, const X& x)
|
||||
{
|
||||
using Y = ContainerElementPicker<Arr, Picks>;
|
||||
constexpr index_t nsize = Y::Size();
|
||||
|
||||
static_assert(nsize == X::Size(), "wrong! size not the same");
|
||||
|
||||
static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; });
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename Arr, typename Picks, typename X>
|
||||
__host__ __device__ constexpr auto operator-=(ContainerElementPicker<Arr, Picks>& y, const X& x)
|
||||
{
|
||||
using Y = ContainerElementPicker<Arr, Picks>;
|
||||
constexpr index_t nsize = Y::Size();
|
||||
|
||||
static_assert(nsize == X::Size(), "wrong! size not the same");
|
||||
|
||||
static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; });
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename Arr, typename Picks>
|
||||
__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks)
|
||||
{
|
||||
return ContainerElementPicker<Arr, Picks>(a);
|
||||
}
|
||||
|
||||
template <typename Arr, typename Picks>
|
||||
__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks)
|
||||
{
|
||||
return ConstantContainerElementPicker<Arr, Picks>(a);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
403
composable_kernel/include/utility/container_helper.hpp
Normal file
403
composable_kernel/include/utility/container_helper.hpp
Normal file
@@ -0,0 +1,403 @@
|
||||
#ifndef CK_CONTAINER_HELPER_HPP
|
||||
#define CK_CONTAINER_HELPER_HPP
|
||||
|
||||
#include "sequence.hpp"
|
||||
#include "sequence_helper.hpp"
|
||||
#include "array.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "tuple_helper.hpp"
|
||||
#include "statically_indexed_array.hpp"
|
||||
#include "container_element_picker.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename TData, index_t NSize>
|
||||
__host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>& a, const TData& x)
|
||||
{
|
||||
Array<TData, NSize + 1> r;
|
||||
|
||||
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
|
||||
|
||||
r(Number<NSize>{}) = x;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Ts, typename T>
|
||||
__host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
|
||||
{
|
||||
return container_concat(make_tuple(x), a);
|
||||
}
|
||||
|
||||
template <typename... Ts, typename T>
|
||||
__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
|
||||
{
|
||||
return container_concat(a, make_tuple(x));
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reorder_given_new2old(const Array<TData, NSize>& old_array, Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return make_array(old_array[Number<IRs>{}]...);
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, index_t... IRs>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reorder_given_old2new(const Array<TData, NSize>& old_array, Sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... IRs>
|
||||
__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple<Ts...>& old_tuple,
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return make_tuple(old_tuple[Number<IRs>{}]...);
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... IRs>
|
||||
__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<Ts...>& old_tuple,
|
||||
Sequence<IRs...> old2new)
|
||||
{
|
||||
return container_reorder_given_new2old(
|
||||
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
|
||||
}
|
||||
|
||||
template <index_t... Is, index_t... IRs>
|
||||
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
|
||||
Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
return Sequence<Sequence<Is...>::At(Number<IRs>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... Is, index_t... IRs>
|
||||
__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is...> old_seq,
|
||||
Sequence<IRs...> /* old2new */)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
|
||||
|
||||
constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};
|
||||
|
||||
return container_reorder_given_new2old(old_seq, new2old);
|
||||
}
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_275126
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::Size(),
|
||||
index_t IStep = 1>
|
||||
__host__ __device__ constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
Number<IBegin> = Number<0>{},
|
||||
Number<IEnd> = Number<Container::Size()>{},
|
||||
Number<IStep> = Number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i + Number<IStep>{}, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, Number<IBegin>{}, init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename ROld,
|
||||
index_t I,
|
||||
index_t IEnd,
|
||||
index_t IStep>
|
||||
__host__ __device__ constexpr auto container_reduce_impl(
|
||||
const Container& x, Reduce reduce, ROld r_old, Number<I> i, Number<IEnd>, Number<IStep>)
|
||||
{
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
if constexpr(i.value < IEnd - IStep)
|
||||
{
|
||||
return container_reduce_impl(
|
||||
x, reduce, r_new, i + Number<IStep>{}, Number<IEnd>{}, Number<IStep>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return r_new;
|
||||
}
|
||||
}
|
||||
|
||||
// rocm-4.1 compiler would crash for recursive lambda
|
||||
// container reduce with initial value
|
||||
template <typename Container,
|
||||
typename Reduce,
|
||||
typename Init,
|
||||
index_t IBegin = 0,
|
||||
index_t IEnd = Container::Size(),
|
||||
index_t IStep = 1>
|
||||
__host__ __device__ constexpr auto container_reduce(const Container& x,
|
||||
Reduce reduce,
|
||||
Init init,
|
||||
Number<IBegin> = Number<0>{},
|
||||
Number<IEnd> = Number<Container::Size()>{},
|
||||
Number<IStep> = Number<1>{})
|
||||
{
|
||||
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
|
||||
|
||||
if constexpr(IEnd > IBegin)
|
||||
{
|
||||
return container_reduce_impl(
|
||||
x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return init;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename TData, index_t NSize, typename Reduce>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reverse_inclusive_scan(const Array<TData, NSize>& x, Reduce f, TData init)
|
||||
{
|
||||
Array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
r = f(r, x[i]);
|
||||
y(i) = r;
|
||||
});
|
||||
|
||||
r = f(r, x[Number<0>{}]);
|
||||
y(Number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename TData, index_t NSize, typename Reduce>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData init)
|
||||
{
|
||||
Array<TData, NSize> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
y(i) = r;
|
||||
r = f(r, x[i]);
|
||||
});
|
||||
|
||||
y(Number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <index_t... Is, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reverse_exclusive_scan(const Sequence<Is...>& seq, Reduce f, Number<Init>)
|
||||
{
|
||||
return reverse_exclusive_scan_sequence(seq, f, Number<Init>{});
|
||||
}
|
||||
|
||||
#if !CK_WORKAROUND_SWDEV_275126
|
||||
// rocm4.1 compiler would crash with recursive lambda
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
// f is recursive function, fs is a dummy of f
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return fs(fs, i - Number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
};
|
||||
|
||||
// start recursion
|
||||
return f(f, Number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#else
|
||||
// i is index, y_old is current scan, r_old is current reduction
|
||||
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
|
||||
__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl(
|
||||
const Tuple<Xs...>& x, Reduce reduce, Number<I> i, YOld y_old, ROld r_old)
|
||||
{
|
||||
auto r_new = reduce(x[i], r_old);
|
||||
|
||||
auto y_new = container_push_front(y_old, r_new);
|
||||
|
||||
if constexpr(i.value > 1)
|
||||
{
|
||||
// recursively call f/fs
|
||||
return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new);
|
||||
}
|
||||
else
|
||||
{
|
||||
return y_new;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename... Xs, typename Reduce, typename Init>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
return container_reverse_exclusive_scan_impl(
|
||||
x, reduce, Number<NSize - 1>{}, make_tuple(init), init);
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<>
|
||||
template <typename... Xs, typename Reduce, typename TData>
|
||||
__host__ __device__ constexpr auto
|
||||
container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
Tuple<Xs...> y;
|
||||
|
||||
TData r = init;
|
||||
|
||||
static_for<NSize - 1, 0, -1>{}([&](auto i) {
|
||||
r = f(r, x[i]);
|
||||
y(i) = r;
|
||||
});
|
||||
|
||||
r = f(r, x[Number<0>{}]);
|
||||
y(Number<0>{}) = r;
|
||||
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
|
||||
{
|
||||
return container_concat(x, container_concat(ys...));
|
||||
}
|
||||
|
||||
template <typename T, index_t NX, index_t NY>
|
||||
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
|
||||
}
|
||||
|
||||
template <typename... X, typename... Y>
|
||||
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
|
||||
{
|
||||
return unpack2(
|
||||
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
__host__ __device__ constexpr auto container_concat(const Container& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
__host__ __device__ constexpr auto get_container_subset(const Array<T, N>& arr, Sequence<Is...>)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
return make_array(arr[Number<Is>{}]...);
|
||||
}
|
||||
|
||||
template <typename... Ts, index_t... Is>
|
||||
__host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup, Sequence<Is...>)
|
||||
{
|
||||
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
|
||||
|
||||
return make_tuple(tup[Number<Is>{}]...);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, index_t... Is>
|
||||
__host__ __device__ constexpr void
|
||||
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
|
||||
{
|
||||
static_assert(N >= sizeof...(Is), "wrong! size");
|
||||
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
|
||||
template <typename... Ys, index_t... Is, typename... Xs>
|
||||
__host__ __device__ constexpr void
|
||||
set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>& x)
|
||||
{
|
||||
static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size");
|
||||
|
||||
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
|
||||
}
|
||||
|
||||
template <typename Container>
|
||||
__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
|
||||
{
|
||||
static_assert(is_known_at_compile_time<Container>::value, "wrong!");
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t tmp = Container::At(i);
|
||||
return Number<tmp>{};
|
||||
},
|
||||
Container::Size());
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
|
||||
{
|
||||
using Seq = Sequence<Is...>;
|
||||
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t tmp = Seq::At(i);
|
||||
return Number<tmp>{};
|
||||
},
|
||||
Seq::Size());
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
1017
composable_kernel/include/utility/data_type.hpp
Normal file
1017
composable_kernel/include/utility/data_type.hpp
Normal file
File diff suppressed because it is too large
Load Diff
19
composable_kernel/include/utility/data_type_enum.hpp
Normal file
19
composable_kernel/include/utility/data_type_enum.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef CK_DATA_TYPE_ENUM_HPP
|
||||
#define CK_DATA_TYPE_ENUM_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum DataTypeEnum_t
|
||||
{
|
||||
Half = 0,
|
||||
Float = 1,
|
||||
Int32 = 2,
|
||||
Int8 = 3,
|
||||
Int8x4 = 4,
|
||||
BFloat16 = 5,
|
||||
Double = 6,
|
||||
Unknown = 100,
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
76
composable_kernel/include/utility/data_type_enum_helper.hpp
Normal file
76
composable_kernel/include/utility/data_type_enum_helper.hpp
Normal file
@@ -0,0 +1,76 @@
|
||||
#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
|
||||
#define CK_DATA_TYPE_ENUM_HELPER_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
#include "data_type_enum.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <DataTypeEnum_t DataTypeEnum>
|
||||
struct get_datatype_from_enum;
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum_t::Int8>
|
||||
{
|
||||
using type = int8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum_t::Int32>
|
||||
{
|
||||
using type = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum_t::Half>
|
||||
{
|
||||
using type = half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum_t::Float>
|
||||
{
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_from_enum<DataTypeEnum_t::Double>
|
||||
{
|
||||
using type = double;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct get_datatype_enum_from_type;
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<int8_t>
|
||||
{
|
||||
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<int32_t>
|
||||
{
|
||||
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<half_t>
|
||||
{
|
||||
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<float>
|
||||
{
|
||||
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct get_datatype_enum_from_type<double>
|
||||
{
|
||||
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
246
composable_kernel/include/utility/dynamic_buffer.hpp
Normal file
246
composable_kernel/include/utility/dynamic_buffer.hpp
Normal file
@@ -0,0 +1,246 @@
|
||||
#ifndef CK_BUFFER_HPP
|
||||
#define CK_BUFFER_HPP
|
||||
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
typename ElementSpaceSize,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct DynamicBuffer
|
||||
{
|
||||
using type = T;
|
||||
|
||||
T* p_data_;
|
||||
ElementSpaceSize element_space_size_;
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
|
||||
: p_data_{p_data}, element_space_size_{element_space_size}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr DynamicBuffer(T* p_data,
|
||||
ElementSpaceSize element_space_size,
|
||||
T invalid_element_value)
|
||||
: p_data_{p_data},
|
||||
element_space_size_{element_space_size},
|
||||
invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
bool constexpr use_amd_buffer_addressing = true;
|
||||
#else
|
||||
bool constexpr use_amd_buffer_addressing = false;
|
||||
#endif
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
|
||||
{
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_return_zero<
|
||||
remove_cv_t<remove_reference_t<T>>,
|
||||
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return amd_buffer_load_invalid_element_return_customized_value<
|
||||
remove_cv_t<remove_reference_t<T>>,
|
||||
t_per_x>(
|
||||
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
|
||||
: X{invalid_element_value_};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
|
||||
{
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector =
|
||||
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X need to be multiple T");
|
||||
|
||||
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
|
||||
{
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
|
||||
x, p_data_, i, is_valid_element, element_space_size_);
|
||||
#else
|
||||
if(is_valid_element)
|
||||
{
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
#else
|
||||
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
|
||||
// inefficient
|
||||
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
|
||||
// ds_write_b128
|
||||
// TODO: remove this after compiler fix
|
||||
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
|
||||
int8_t>::value)
|
||||
{
|
||||
static_assert(
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
|
||||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
|
||||
"wrong! not implemented for this combination, please add "
|
||||
"implementation");
|
||||
|
||||
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int8_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int16_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x4_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x8_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
|
||||
int8x16_t>::value &&
|
||||
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x4_t*>(&x);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
|
||||
{
|
||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
|
||||
__host__ __device__ constexpr auto
|
||||
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
|
||||
{
|
||||
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
|
||||
p, element_space_size, invalid_element_value};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
13
composable_kernel/include/utility/enable_if.hpp
Normal file
13
composable_kernel/include/utility/enable_if.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef CK_ENABLE_IF_HPP
|
||||
#define CK_ENABLE_IF_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <bool B, typename T = void>
|
||||
using enable_if = std::enable_if<B, T>;
|
||||
|
||||
template <bool B, typename T = void>
|
||||
using enable_if_t = typename std::enable_if<B, T>::type;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
116
composable_kernel/include/utility/functional.hpp
Normal file
116
composable_kernel/include/utility/functional.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
#ifndef CK_FUNCTIONAL_HPP
|
||||
#define CK_FUNCTIONAL_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// TODO: right? wrong?
|
||||
struct forwarder
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T&& operator()(T&& x) const
|
||||
{
|
||||
return static_cast<T&&>(x);
|
||||
}
|
||||
};
|
||||
|
||||
struct swallow
|
||||
{
|
||||
template <typename... Ts>
|
||||
__host__ __device__ constexpr swallow(Ts&&...)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct logical_and
|
||||
{
|
||||
constexpr bool operator()(const T& x, const T& y) const { return x && y; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct logical_or
|
||||
{
|
||||
constexpr bool operator()(const T& x, const T& y) const { return x || y; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct logical_not
|
||||
{
|
||||
constexpr bool operator()(const T& x) const { return !x; }
|
||||
};
|
||||
|
||||
// Emulate if constexpr
|
||||
template <bool>
|
||||
struct static_if;
|
||||
|
||||
template <>
|
||||
struct static_if<true>
|
||||
{
|
||||
using Type = static_if<true>;
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ constexpr auto operator()(F f) const
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will
|
||||
// use it,
|
||||
// this will make "f" a generic lambda, so that "f" won't be compiled
|
||||
// until being
|
||||
// instantiated here
|
||||
f(forwarder{});
|
||||
return Type{};
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ static void Else(F)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_if<false>
|
||||
{
|
||||
using Type = static_if<false>;
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ constexpr auto operator()(F) const
|
||||
{
|
||||
return Type{};
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ static void Else(F f)
|
||||
{
|
||||
// This is a trick for compiler:
|
||||
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will
|
||||
// use it,
|
||||
// this will make "f" a generic lambda, so that "f" won't be compiled
|
||||
// until being
|
||||
// instantiated here
|
||||
f(forwarder{});
|
||||
}
|
||||
};
|
||||
|
||||
template <bool predicate, class X, class Y>
|
||||
struct conditional;
|
||||
|
||||
template <class X, class Y>
|
||||
struct conditional<true, X, Y>
|
||||
{
|
||||
using type = X;
|
||||
};
|
||||
|
||||
template <class X, class Y>
|
||||
struct conditional<false, X, Y>
|
||||
{
|
||||
using type = Y;
|
||||
};
|
||||
|
||||
template <bool predicate, class X, class Y>
|
||||
using conditional_t = typename conditional<predicate, X, Y>::type;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
48
composable_kernel/include/utility/functional2.hpp
Normal file
48
composable_kernel/include/utility/functional2.hpp
Normal file
@@ -0,0 +1,48 @@
|
||||
#ifndef CK_FUNCTIONAL2_HPP
|
||||
#define CK_FUNCTIONAL2_HPP
|
||||
|
||||
#include "functional.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class>
|
||||
struct static_for_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct static_for_impl<Sequence<Is...>>
|
||||
{
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
swallow{(f(Number<Is>{}), 0)...};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// F signature: F(Number<Iter>)
|
||||
template <index_t NBegin, index_t NEnd, index_t Increment>
|
||||
struct static_for
|
||||
{
|
||||
__host__ __device__ constexpr static_for()
|
||||
{
|
||||
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
|
||||
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
|
||||
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
|
||||
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
|
||||
"NBegin >= NEnd)");
|
||||
}
|
||||
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
|
||||
f);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
142
composable_kernel/include/utility/functional3.hpp
Normal file
142
composable_kernel/include/utility/functional3.hpp
Normal file
@@ -0,0 +1,142 @@
|
||||
#ifndef CK_FUNCTIONAL3_HPP
|
||||
#define CK_FUNCTIONAL3_HPP
|
||||
|
||||
#include "functional.hpp"
|
||||
#include "functional2.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "multi_index.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
// Orders: Sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct static_ford_impl
|
||||
{
|
||||
__host__ __device__ constexpr static_ford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
// F signature: F(Sequence<...>)
|
||||
// CurrentOrderedId: Sequence<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
|
||||
{
|
||||
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
|
||||
f, CurrentOrderedId::PushBack(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct static_ford_impl<Sequence<>, Orders>
|
||||
{
|
||||
// F signature: F(Sequence<...>)
|
||||
// OrderedId: Sequence<...>
|
||||
template <class F, class OrderedId>
|
||||
__host__ __device__ constexpr void operator()(F f, OrderedId) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(OrderedId::ReorderGivenOld2New(Orders{}));
|
||||
}
|
||||
};
|
||||
|
||||
// RemainLengths: Sequence<...>
|
||||
// Orders: Sequence<...>
|
||||
template <class RemainLengths, class Orders>
|
||||
struct ford_impl
|
||||
{
|
||||
__host__ __device__ constexpr ford_impl()
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
}
|
||||
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentOrderdId: Array<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
|
||||
{
|
||||
for(index_t i = 0; i < RemainLengths::Front(); ++i)
|
||||
{
|
||||
ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
|
||||
f, container_push_back(current_ordered_id, i));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class Orders>
|
||||
struct ford_impl<Sequence<>, Orders>
|
||||
{
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// CurrentOrderdId: Array<...>
|
||||
template <class F, class CurrentOrderedId>
|
||||
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
|
||||
{
|
||||
// retrive unordered Id
|
||||
f(container_reorder_given_old2new(current_ordered_id, Orders{}));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which static_ford
|
||||
// will loop over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
|
||||
struct static_ford
|
||||
{
|
||||
__host__ __device__ constexpr static_ford()
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
|
||||
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
// Lengths is Sequence<...>, it is the length of each dimension for
|
||||
// N-dimensional loop
|
||||
// Orders is Sequence<...>, it is the order of dimension in which ford will loop
|
||||
// over each
|
||||
// dimension
|
||||
template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
|
||||
struct ford
|
||||
{
|
||||
__host__ __device__ constexpr ford()
|
||||
{
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
|
||||
}
|
||||
|
||||
// F signature: F(Array<...> multi_id)
|
||||
// multi_id is the unordered multi-index
|
||||
template <class F>
|
||||
__host__ __device__ constexpr void operator()(F f) const
|
||||
{
|
||||
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
|
||||
|
||||
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
|
||||
{
|
||||
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
|
||||
make_multi_index(i));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
62
composable_kernel/include/utility/functional4.hpp
Normal file
62
composable_kernel/include/utility/functional4.hpp
Normal file
@@ -0,0 +1,62 @@
|
||||
#ifndef CK_FUNCTIONAL4_HPP
|
||||
#define CK_FUNCTIONAL4_HPP
|
||||
|
||||
#include "sequence.hpp"
|
||||
#include "tuple.hpp"
|
||||
#include "array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename Indices>
|
||||
struct unpack_impl;
|
||||
|
||||
template <index_t... Is>
|
||||
struct unpack_impl<Sequence<Is...>>
|
||||
{
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Seq0, typename Seq1>
|
||||
struct unpack2_impl;
|
||||
|
||||
// TODO: remove this, after properly implementing unpack that takes any number of containers
|
||||
template <index_t... Is, index_t... Js>
|
||||
struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
|
||||
{
|
||||
template <typename F, typename X, typename Y>
|
||||
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
|
||||
std::forward<Y>(y).At(Number<Js>{})...);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto unpack(F&& f, X&& x)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x));
|
||||
}
|
||||
|
||||
// TODO: properly implement unpack that takes any number of containers
|
||||
template <typename F, typename X, typename Y>
|
||||
__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
|
||||
{
|
||||
using X_ = remove_reference_t<X>;
|
||||
using Y_ = remove_reference_t<Y>;
|
||||
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
|
||||
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
|
||||
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
207
composable_kernel/include/utility/inner_product.hpp
Normal file
207
composable_kernel/include/utility/inner_product.hpp
Normal file
@@ -0,0 +1,207 @@
|
||||
#ifndef CK_INNER_PRODUCT_HPP
|
||||
#define CK_INNER_PRODUCT_HPP
|
||||
|
||||
#include "data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename TA, typename TB, typename TC>
|
||||
__device__ void inner_product(const TA& a, const TB& b, TC& c);
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
|
||||
{
|
||||
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %1, %2 \n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(a), "v"(b), "0"(c));
|
||||
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
|
||||
asm volatile("\n \
|
||||
v_fmac_f32 %0, %1, %2 \n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(a), "v"(b), "0"(c));
|
||||
#else
|
||||
c += a * b;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
|
||||
vector_type<float, 2>{b}.AsType<float>()[I0],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
|
||||
vector_type<float, 2>{b}.AsType<float>()[I1],
|
||||
c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
|
||||
vector_type<float, 4>{b}.AsType<float>()[I0],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
|
||||
vector_type<float, 4>{b}.AsType<float>()[I1],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
|
||||
vector_type<float, 4>{b}.AsType<float>()[I2],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
|
||||
vector_type<float, 4>{b}.AsType<float>()[I3],
|
||||
c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
|
||||
{
|
||||
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
|
||||
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %1, %2, %0\n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(a), "v"(b), "0"(c));
|
||||
#else
|
||||
c = __builtin_amdgcn_sdot2(a, b, c, false);
|
||||
#endif
|
||||
#else
|
||||
const auto convert = type_convert<int32_t>{};
|
||||
|
||||
const vector_type<half_t, 2> a_vector{a};
|
||||
const vector_type<half_t, 2> b_vector{b};
|
||||
|
||||
static_for<0, 2, 1>{}([&](auto i) {
|
||||
c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
|
||||
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
|
||||
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
|
||||
c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
|
||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
|
||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
|
||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
|
||||
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
|
||||
c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
|
||||
{
|
||||
#if defined(CK_USE_DOT4_I32_I8)
|
||||
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
|
||||
asm volatile("\n \
|
||||
v_dot4_i32_i8 %0, %1, %2, %0\n \
|
||||
"
|
||||
: "=v"(c)
|
||||
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
|
||||
#else
|
||||
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
|
||||
#endif
|
||||
#else
|
||||
const auto convert = type_convert<int32_t>{};
|
||||
|
||||
const vector_type<int8_t, 4> a_vector{a};
|
||||
const vector_type<int8_t, 4> b_vector{b};
|
||||
|
||||
static_for<0, 4, 1>{}([&](auto i) {
|
||||
c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
|
||||
c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void
|
||||
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
|
||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
|
||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
|
||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
|
||||
c);
|
||||
|
||||
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
|
||||
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
|
||||
c);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
17
composable_kernel/include/utility/integral_constant.hpp
Normal file
17
composable_kernel/include/utility/integral_constant.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
#ifndef CK_INTEGRAL_CONSTANT_HPP
|
||||
#define CK_INTEGRAL_CONSTANT_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class T, T v>
|
||||
struct integral_constant
|
||||
{
|
||||
static constexpr T value = v;
|
||||
typedef T value_type;
|
||||
typedef integral_constant type;
|
||||
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
|
||||
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
155
composable_kernel/include/utility/magic_division.hpp
Normal file
155
composable_kernel/include/utility/magic_division.hpp
Normal file
@@ -0,0 +1,155 @@
|
||||
#ifndef CK_MAGIC_DIVISION_HPP
|
||||
#define CK_MAGIC_DIVISION_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
#include "tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// magic number division
|
||||
// Caution:
|
||||
// 1. For uint32_t as dividend: magic number division implementation being used would produce
|
||||
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
|
||||
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
|
||||
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
|
||||
// division implementation for uint32_t is then used. Therefore, dividend value need to be
|
||||
// non-negative.
|
||||
// TODO:
|
||||
// 1. Implement magic number divison for int32_t
|
||||
// 2. Implement magic number divison for unit32_t with 32-bit value range
|
||||
struct MagicDivision
|
||||
{
|
||||
// uint32_t
|
||||
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
|
||||
{
|
||||
// assert(divisior >= 1 && divisior <= INT32_MAX);
|
||||
uint32_t shift = 0;
|
||||
for(shift = 0; shift < 32; ++shift)
|
||||
{
|
||||
if((1U << shift) >= divisor)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
uint64_t one = 1;
|
||||
uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
|
||||
// assert(multiplier <= 0xffffffffUL);
|
||||
|
||||
return make_tuple(uint32_t(multiplier), shift);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
|
||||
{
|
||||
auto tmp = CalculateMagicNumbers(divisor);
|
||||
|
||||
return tmp[Number<0>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor)
|
||||
{
|
||||
auto tmp = CalculateMagicNumbers(divisor);
|
||||
|
||||
return tmp[Number<1>{}];
|
||||
}
|
||||
|
||||
// integral_constant<uint32_t, .>
|
||||
template <uint32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
|
||||
|
||||
constexpr uint32_t multiplier = tmp[Number<0>{}];
|
||||
constexpr uint32_t shift = tmp[Number<1>{}];
|
||||
|
||||
return make_tuple(integral_constant<uint32_t, multiplier>{},
|
||||
integral_constant<uint32_t, shift>{});
|
||||
}
|
||||
|
||||
template <uint32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
|
||||
|
||||
return integral_constant<uint32_t, multiplier>{};
|
||||
}
|
||||
|
||||
template <uint32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicShift(integral_constant<uint32_t, Divisor>)
|
||||
{
|
||||
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
|
||||
|
||||
return integral_constant<uint32_t, shift>{};
|
||||
}
|
||||
|
||||
// integral_constant<int32_t, .>
|
||||
template <int32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
template <int32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
template <int32_t Divisor>
|
||||
__host__ __device__ static constexpr auto
|
||||
CalculateMagicShift(integral_constant<int32_t, Divisor>)
|
||||
{
|
||||
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
|
||||
}
|
||||
|
||||
// magic division for uint32_t
|
||||
__host__ __device__ static constexpr uint32_t
|
||||
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
|
||||
return (tmp + dividend) >> shift;
|
||||
}
|
||||
|
||||
#if 1 // debug
|
||||
// HACK: magic division for int32_t
|
||||
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
|
||||
// non-negative for result to be correct
|
||||
// TODO: figure out how to do magic number divison for int32_t as dividended
|
||||
__host__ __device__ static constexpr int32_t
|
||||
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
|
||||
uint32_t tmp =
|
||||
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
|
||||
return (tmp + dividend_u32) >> shift;
|
||||
}
|
||||
#else
|
||||
// the inline ASM is producing wrong result
|
||||
__host__ __device__ static int32_t
|
||||
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
|
||||
{
|
||||
uint32_t r;
|
||||
asm volatile("\n \
|
||||
v_mul_hi_u32 %0, %1, %2 \n \
|
||||
v_add_u32_e32 %0, %1, %0 \n \
|
||||
v_lshrrev_b32_e32 %0, %3, %0 \n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
|
||||
|
||||
return as_type<int32_t>(r);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
216
composable_kernel/include/utility/math.hpp
Normal file
216
composable_kernel/include/utility/math.hpp
Normal file
@@ -0,0 +1,216 @@
|
||||
#ifndef CK_MATH_HPP
|
||||
#define CK_MATH_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace math {
|
||||
|
||||
template <typename T, T s>
|
||||
struct scales
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct minus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
|
||||
};
|
||||
|
||||
struct multiplies
|
||||
{
|
||||
template <typename A, typename B>
|
||||
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const
|
||||
{
|
||||
return a * b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct maximize
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct minimize
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct integer_divide_ceiler
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
|
||||
|
||||
return (a + b - Number<1>{}) / b;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
|
||||
{
|
||||
return x / y;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
|
||||
{
|
||||
return (x + y - Number<1>{}) / y;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
|
||||
{
|
||||
return y * integer_divide_ceil(x, y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T max(T x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T max(T x, T y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
|
||||
{
|
||||
return X > y ? X : y;
|
||||
}
|
||||
|
||||
template <index_t Y>
|
||||
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
|
||||
{
|
||||
return x > Y ? x : Y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
__host__ __device__ constexpr auto max(X x, Ys... ys)
|
||||
{
|
||||
static_assert(sizeof...(Ys) > 0, "not enough argument");
|
||||
|
||||
return max(x, max(ys...));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T min(T x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T min(T x, T y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
template <index_t X>
|
||||
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
|
||||
{
|
||||
return X < y ? X : y;
|
||||
}
|
||||
|
||||
template <index_t Y>
|
||||
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
|
||||
{
|
||||
return x < Y ? x : Y;
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
__host__ __device__ constexpr auto min(X x, Ys... ys)
|
||||
{
|
||||
static_assert(sizeof...(Ys) > 0, "not enough argument");
|
||||
|
||||
return min(x, min(ys...));
|
||||
}
|
||||
|
||||
// greatest common divisor, aka highest common factor
|
||||
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
|
||||
{
|
||||
if(x < 0)
|
||||
{
|
||||
return gcd(-x, y);
|
||||
}
|
||||
else if(y < 0)
|
||||
{
|
||||
return gcd(x, -y);
|
||||
}
|
||||
else if(x == y || x == 0)
|
||||
{
|
||||
return y;
|
||||
}
|
||||
else if(y == 0)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
else if(x > y)
|
||||
{
|
||||
return gcd(x % y, y);
|
||||
}
|
||||
else
|
||||
{
|
||||
return gcd(x, y % x);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
|
||||
{
|
||||
constexpr auto r = gcd(X, Y);
|
||||
|
||||
return Number<r>{};
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
||||
{
|
||||
return gcd(x, gcd(ys...));
|
||||
}
|
||||
|
||||
// least common multiple
|
||||
template <typename X, typename Y>
|
||||
__host__ __device__ constexpr auto lcm(X x, Y y)
|
||||
{
|
||||
return (x * y) / gcd(x, y);
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
|
||||
{
|
||||
return lcm(x, lcm(ys...));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct equal
|
||||
{
|
||||
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct less
|
||||
{
|
||||
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
|
||||
};
|
||||
|
||||
} // namespace math
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
12
composable_kernel/include/utility/multi_index.hpp
Normal file
12
composable_kernel/include/utility/multi_index.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#ifndef CK_MULTI_INDEX_HPP
|
||||
#define CK_MULTI_INDEX_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX
|
||||
#include "array_multi_index.hpp"
|
||||
#else
|
||||
#include "statically_indexed_array_multi_index.hpp"
|
||||
#endif
|
||||
|
||||
#endif
|
||||
44
composable_kernel/include/utility/number.hpp
Normal file
44
composable_kernel/include/utility/number.hpp
Normal file
@@ -0,0 +1,44 @@
|
||||
#ifndef CK_NUMBER_HPP
|
||||
#define CK_NUMBER_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using Number = integral_constant<index_t, N>;
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
|
||||
{
|
||||
return Number<X + Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y <= X, "wrong!");
|
||||
return Number<X - Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
|
||||
{
|
||||
return Number<X * Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y > 0, "wrong!");
|
||||
return Number<X / Y>{};
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
|
||||
{
|
||||
static_assert(Y > 0, "wrong!");
|
||||
return Number<X % Y>{};
|
||||
}
|
||||
} // namespace ck
|
||||
#endif
|
||||
22
composable_kernel/include/utility/print.hpp
Normal file
22
composable_kernel/include/utility/print.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
#ifndef CK_PRINT_HPP
|
||||
#define CK_PRINT_HPP
|
||||
|
||||
#include "array.hpp"
|
||||
#include "statically_indexed_array.hpp"
|
||||
#include "container_helper.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void print_array(const char* s, T a)
|
||||
{
|
||||
constexpr index_t nsize = a.Size();
|
||||
|
||||
printf("%s size %d, {", s, nsize);
|
||||
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
|
||||
printf("}\n");
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
880
composable_kernel/include/utility/sequence.hpp
Normal file
880
composable_kernel/include/utility/sequence.hpp
Normal file
@@ -0,0 +1,880 @@
|
||||
#ifndef CK_SEQUENCE_HPP
|
||||
#define CK_SEQUENCE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "type.hpp"
|
||||
#include "functional.hpp"
|
||||
#include "math.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t, index_t, index_t>
|
||||
struct static_for;
|
||||
|
||||
template <index_t...>
|
||||
struct Sequence;
|
||||
|
||||
template <typename Seq, index_t I>
|
||||
struct sequence_split;
|
||||
|
||||
template <typename>
|
||||
struct sequence_reverse;
|
||||
|
||||
template <typename>
|
||||
struct sequence_map_inverse;
|
||||
|
||||
template <typename>
|
||||
struct is_valid_sequence_map;
|
||||
|
||||
template <index_t I, index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
|
||||
|
||||
template <typename Seq>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Seq);
|
||||
|
||||
template <index_t... Is>
|
||||
struct Sequence
|
||||
{
|
||||
using Type = Sequence;
|
||||
using data_type = index_t;
|
||||
|
||||
static constexpr index_t mSize = sizeof...(Is);
|
||||
|
||||
__host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
|
||||
|
||||
__host__ __device__ static constexpr auto GetSize() { return Size(); }
|
||||
|
||||
__host__ __device__ static constexpr index_t At(index_t I)
|
||||
{
|
||||
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
|
||||
const index_t mData[mSize + 1] = {Is..., 0};
|
||||
return mData[I];
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto At(Number<I>)
|
||||
{
|
||||
static_assert(I < mSize, "wrong! I too large");
|
||||
|
||||
return Number<At(I)>{};
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static constexpr auto Get(Number<I>)
|
||||
{
|
||||
return At(Number<I>{});
|
||||
}
|
||||
|
||||
template <typename I>
|
||||
__host__ __device__ constexpr auto operator[](I i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <index_t... IRs>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(IRs),
|
||||
"wrong! reorder map should have the same size as Sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
|
||||
|
||||
return Sequence<Type::At(Number<IRs>{})...>{};
|
||||
}
|
||||
|
||||
// MapOld2New is Sequence<...>
|
||||
template <typename MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
|
||||
{
|
||||
static_assert(MapOld2New::Size() == Size(),
|
||||
"wrong! reorder map should have the same size as Sequence to be rerodered");
|
||||
|
||||
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
|
||||
|
||||
return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Reverse()
|
||||
{
|
||||
return typename sequence_reverse<Type>::type{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Front()
|
||||
{
|
||||
static_assert(mSize > 0, "wrong!");
|
||||
return At(Number<0>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Back()
|
||||
{
|
||||
static_assert(mSize > 0, "wrong!");
|
||||
return At(Number<mSize - 1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
|
||||
|
||||
__host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<Xs..., Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushFront(Number<Xs>...)
|
||||
{
|
||||
return Sequence<Xs..., Is...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs>
|
||||
__host__ __device__ static constexpr auto PushBack(Number<Xs>...)
|
||||
{
|
||||
return Sequence<Is..., Xs...>{};
|
||||
}
|
||||
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
|
||||
{
|
||||
return Sequence<Type::At(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t... Ns>
|
||||
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
|
||||
{
|
||||
return Sequence<Type::At(Number<Ns>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t I, index_t X>
|
||||
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
|
||||
{
|
||||
static_assert(I < Size(), "wrong!");
|
||||
|
||||
using seq_split = sequence_split<Type, I>;
|
||||
constexpr auto seq_left = typename seq_split::left_type{};
|
||||
constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
|
||||
|
||||
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
__host__ __device__ static constexpr auto Transform(F f)
|
||||
{
|
||||
return Sequence<f(Is)...>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static void Print()
|
||||
{
|
||||
printf("{");
|
||||
printf("size %d, ", index_t{Size()});
|
||||
static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
|
||||
printf("}");
|
||||
}
|
||||
};
|
||||
|
||||
// merge sequence
|
||||
template <typename Seq, typename... Seqs>
|
||||
struct sequence_merge
|
||||
{
|
||||
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
template <typename Seq>
|
||||
struct sequence_merge<Seq>
|
||||
{
|
||||
using type = Seq;
|
||||
};
|
||||
|
||||
// generate sequence
|
||||
template <index_t NSize, typename F>
|
||||
struct sequence_gen
|
||||
{
|
||||
template <index_t IBegin, index_t NRemain, typename G>
|
||||
struct sequence_gen_impl
|
||||
{
|
||||
static constexpr index_t NRemainLeft = NRemain / 2;
|
||||
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
|
||||
static constexpr index_t IMiddle = IBegin + NRemainLeft;
|
||||
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
|
||||
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 1, G>
|
||||
{
|
||||
static constexpr index_t Is = G{}(Number<I>{});
|
||||
using type = Sequence<Is>;
|
||||
};
|
||||
|
||||
template <index_t I, typename G>
|
||||
struct sequence_gen_impl<I, 0, G>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
using type = typename sequence_gen_impl<0, NSize, F>::type;
|
||||
};
|
||||
|
||||
// arithmetic sequence
|
||||
template <index_t IBegin, index_t IEnd, index_t Increment>
|
||||
struct arithmetic_sequence_gen
|
||||
{
|
||||
struct F
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t i) const
|
||||
{
|
||||
return i * Increment + IBegin;
|
||||
}
|
||||
};
|
||||
|
||||
using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
|
||||
};
|
||||
|
||||
// uniform sequence
|
||||
template <index_t NSize, index_t I>
|
||||
struct uniform_sequence_gen
|
||||
{
|
||||
struct F
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
|
||||
};
|
||||
|
||||
using type = typename sequence_gen<NSize, F>::type;
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
template <typename, typename, index_t>
|
||||
struct sequence_reverse_inclusive_scan;
|
||||
|
||||
template <index_t I, index_t... Is, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
|
||||
{
|
||||
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
|
||||
|
||||
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
|
||||
{
|
||||
using type = Sequence<Reduce{}(I, Init)>;
|
||||
};
|
||||
|
||||
template <typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
|
||||
{
|
||||
using type = Sequence<>;
|
||||
};
|
||||
|
||||
// split sequence
|
||||
template <typename Seq, index_t I>
|
||||
struct sequence_split
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.Size();
|
||||
|
||||
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
|
||||
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
|
||||
|
||||
using left_type = decltype(Seq::Extract(range0{}));
|
||||
using right_type = decltype(Seq::Extract(range1{}));
|
||||
};
|
||||
|
||||
// reverse sequence
|
||||
template <typename Seq>
|
||||
struct sequence_reverse
|
||||
{
|
||||
static constexpr index_t NSize = Seq{}.Size();
|
||||
|
||||
using seq_split = sequence_split<Seq, NSize / 2>;
|
||||
using type = typename sequence_merge<
|
||||
typename sequence_reverse<typename seq_split::right_type>::type,
|
||||
typename sequence_reverse<typename seq_split::left_type>::type>::type;
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct sequence_reverse<Sequence<I>>
|
||||
{
|
||||
using type = Sequence<I>;
|
||||
};
|
||||
|
||||
template <index_t I0, index_t I1>
|
||||
struct sequence_reverse<Sequence<I0, I1>>
|
||||
{
|
||||
using type = Sequence<I1, I0>;
|
||||
};
|
||||
|
||||
#if 1
|
||||
template <typename Reduce, typename Seq, typename... Seqs>
|
||||
struct sequence_reduce
|
||||
{
|
||||
using type = typename sequence_reduce<Reduce,
|
||||
Seq,
|
||||
typename sequence_reduce<Reduce, Seqs...>::type>::type;
|
||||
};
|
||||
|
||||
template <typename Reduce, index_t... Xs, index_t... Ys>
|
||||
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
|
||||
{
|
||||
using type = Sequence<Reduce{}(Xs, Ys)...>;
|
||||
};
|
||||
|
||||
template <typename Reduce, typename Seq>
|
||||
struct sequence_reduce<Reduce, Seq>
|
||||
{
|
||||
using type = Seq;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename Values, typename Ids, typename Compare>
|
||||
struct sequence_sort_impl
|
||||
{
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl
|
||||
{
|
||||
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
|
||||
|
||||
static constexpr index_t chosen_value =
|
||||
choose_left ? LeftValues::Front() : RightValues::Front();
|
||||
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
|
||||
|
||||
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
|
||||
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
|
||||
|
||||
using new_left_values =
|
||||
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
|
||||
using new_left_ids =
|
||||
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
|
||||
|
||||
using new_right_values =
|
||||
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
|
||||
using new_right_ids =
|
||||
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
|
||||
|
||||
using merge = sorted_sequence_merge_impl<new_left_values,
|
||||
new_left_ids,
|
||||
new_right_values,
|
||||
new_right_ids,
|
||||
new_merged_values,
|
||||
new_merged_ids,
|
||||
Comp>;
|
||||
// this is output
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
|
||||
};
|
||||
|
||||
template <typename RightValues,
|
||||
typename RightIds,
|
||||
typename MergedValues,
|
||||
typename MergedIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
RightValues,
|
||||
RightIds,
|
||||
MergedValues,
|
||||
MergedIds,
|
||||
Comp>
|
||||
{
|
||||
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
|
||||
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
|
||||
};
|
||||
|
||||
template <typename LeftValues,
|
||||
typename LeftIds,
|
||||
typename RightValues,
|
||||
typename RightIds,
|
||||
typename Comp>
|
||||
struct sorted_sequence_merge
|
||||
{
|
||||
using merge = sorted_sequence_merge_impl<LeftValues,
|
||||
LeftIds,
|
||||
RightValues,
|
||||
RightIds,
|
||||
Sequence<>,
|
||||
Sequence<>,
|
||||
Comp>;
|
||||
|
||||
using merged_values = typename merge::merged_values;
|
||||
using merged_ids = typename merge::merged_ids;
|
||||
};
|
||||
|
||||
static constexpr index_t nsize = Values::Size();
|
||||
|
||||
using split_unsorted_values = sequence_split<Values, nsize / 2>;
|
||||
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
|
||||
|
||||
using left_unsorted_values = typename split_unsorted_values::left_type;
|
||||
using left_unsorted_ids = typename split_unsorted_ids::left_type;
|
||||
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
|
||||
using left_sorted_values = typename left_sort::sorted_values;
|
||||
using left_sorted_ids = typename left_sort::sorted_ids;
|
||||
|
||||
using right_unsorted_values = typename split_unsorted_values::right_type;
|
||||
using right_unsorted_ids = typename split_unsorted_ids::right_type;
|
||||
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
|
||||
using right_sorted_values = typename right_sort::sorted_values;
|
||||
using right_sorted_ids = typename right_sort::sorted_ids;
|
||||
|
||||
using merged_sorted = sorted_sequence_merge<left_sorted_values,
|
||||
left_sorted_ids,
|
||||
right_sorted_values,
|
||||
right_sorted_ids,
|
||||
Compare>;
|
||||
|
||||
using sorted_values = typename merged_sorted::merged_values;
|
||||
using sorted_ids = typename merged_sorted::merged_ids;
|
||||
};
|
||||
|
||||
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
|
||||
{
|
||||
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
|
||||
|
||||
using sorted_values =
|
||||
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
|
||||
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
|
||||
};
|
||||
|
||||
template <index_t Value, index_t Id, typename Compare>
|
||||
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
|
||||
{
|
||||
using sorted_values = Sequence<Value>;
|
||||
using sorted_ids = Sequence<Id>;
|
||||
};
|
||||
|
||||
template <typename Compare>
|
||||
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
|
||||
{
|
||||
using sorted_values = Sequence<>;
|
||||
using sorted_ids = Sequence<>;
|
||||
};
|
||||
|
||||
template <typename Values, typename Compare>
|
||||
struct sequence_sort
|
||||
{
|
||||
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
|
||||
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
|
||||
|
||||
// this is output
|
||||
using type = typename sort::sorted_values;
|
||||
using sorted2unsorted_map = typename sort::sorted_ids;
|
||||
};
|
||||
|
||||
template <typename Values, typename Less, typename Equal>
|
||||
struct sequence_unique_sort
|
||||
{
|
||||
template <typename RemainValues,
|
||||
typename RemainIds,
|
||||
typename UniquifiedValues,
|
||||
typename UniquifiedIds,
|
||||
typename Eq>
|
||||
struct sorted_sequence_uniquify_impl
|
||||
{
|
||||
static constexpr index_t current_value = RemainValues::Front();
|
||||
static constexpr index_t current_id = RemainIds::Front();
|
||||
|
||||
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
|
||||
|
||||
using new_remain_values = decltype(RemainValues::PopFront());
|
||||
using new_remain_ids = decltype(RemainIds::PopFront());
|
||||
|
||||
using new_uniquified_values =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
|
||||
UniquifiedValues>::type;
|
||||
|
||||
using new_uniquified_ids =
|
||||
typename conditional<is_unique_value,
|
||||
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
|
||||
UniquifiedIds>::type;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
|
||||
new_remain_ids,
|
||||
new_uniquified_values,
|
||||
new_uniquified_ids,
|
||||
Eq>;
|
||||
|
||||
// this is output
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify_impl<Sequence<>,
|
||||
Sequence<>,
|
||||
UniquifiedValues,
|
||||
UniquifiedIds,
|
||||
Eq>
|
||||
{
|
||||
using uniquified_values = UniquifiedValues;
|
||||
using uniquified_ids = UniquifiedIds;
|
||||
};
|
||||
|
||||
template <typename SortedValues, typename SortedIds, typename Eq>
|
||||
struct sorted_sequence_uniquify
|
||||
{
|
||||
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
|
||||
decltype(SortedIds::PopFront()),
|
||||
Sequence<SortedValues::Front()>,
|
||||
Sequence<SortedIds::Front()>,
|
||||
Eq>;
|
||||
|
||||
using uniquified_values = typename uniquify::uniquified_values;
|
||||
using uniquified_ids = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
using sort = sequence_sort<Values, Less>;
|
||||
using sorted_values = typename sort::type;
|
||||
using sorted_ids = typename sort::sorted2unsorted_map;
|
||||
|
||||
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
|
||||
|
||||
// this is output
|
||||
using type = typename uniquify::uniquified_values;
|
||||
using sorted2unsorted_map = typename uniquify::uniquified_ids;
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
|
||||
typename sequence_sort<SeqMap, math::less<index_t>>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename SeqMap>
|
||||
struct sequence_map_inverse
|
||||
{
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
|
||||
struct sequence_map_inverse_impl
|
||||
{
|
||||
static constexpr auto new_y2x =
|
||||
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
|
||||
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
|
||||
type;
|
||||
};
|
||||
|
||||
template <typename X2Y, typename WorkingY2X, index_t XBegin>
|
||||
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
|
||||
{
|
||||
using type = WorkingY2X;
|
||||
};
|
||||
|
||||
using type =
|
||||
typename sequence_map_inverse_impl<SeqMap,
|
||||
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
|
||||
0,
|
||||
SeqMap::Size()>::type;
|
||||
};
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs + Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs - Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs * Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs / Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
|
||||
|
||||
return Sequence<(Xs % Ys)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs + Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs - Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs * Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs / Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t... Xs, index_t Y>
|
||||
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
|
||||
{
|
||||
return Sequence<(Xs % Y)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y + Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y - Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y * Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y / Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t Y, index_t... Xs>
|
||||
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<(Y % Xs)...>{};
|
||||
}
|
||||
|
||||
template <index_t I, index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
|
||||
{
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
template <typename Seq>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Seq)
|
||||
{
|
||||
static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
|
||||
return sequence_pop_front(Seq::Reverse()).Reverse();
|
||||
}
|
||||
|
||||
template <typename... Seqs>
|
||||
__host__ __device__ constexpr auto merge_sequences(Seqs...)
|
||||
{
|
||||
return typename sequence_merge<Seqs...>::type{};
|
||||
}
|
||||
|
||||
template <typename F, index_t... Xs>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
|
||||
{
|
||||
return Sequence<f(Xs)...>{};
|
||||
}
|
||||
|
||||
template <typename F, index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
|
||||
|
||||
return Sequence<f(Xs, Ys)...>{};
|
||||
}
|
||||
|
||||
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
|
||||
Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
|
||||
"Dim not the same");
|
||||
|
||||
return Sequence<f(Xs, Ys, Zs)...>{};
|
||||
}
|
||||
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
|
||||
}
|
||||
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{})
|
||||
.PushBack(Number<Init>{});
|
||||
}
|
||||
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
|
||||
}
|
||||
|
||||
template <typename Seq, index_t... Is>
|
||||
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
|
||||
{
|
||||
return Sequence<Seq::At(Number<Is>{})...>{};
|
||||
}
|
||||
|
||||
#if 1
|
||||
namespace detail {
|
||||
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
|
||||
struct pick_sequence_elements_by_mask_impl
|
||||
{
|
||||
using new_work_seq = typename conditional<RemainMask::Front(),
|
||||
decltype(WorkSeq::PushBack(RemainSeq::Front())),
|
||||
WorkSeq>::type;
|
||||
|
||||
using type =
|
||||
typename pick_sequence_elements_by_mask_impl<new_work_seq,
|
||||
decltype(RemainSeq::PopFront()),
|
||||
decltype(RemainMask::PopFront())>::type;
|
||||
};
|
||||
|
||||
template <typename WorkSeq>
|
||||
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
|
||||
{
|
||||
using type = WorkSeq;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename Seq, typename Mask>
|
||||
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
|
||||
{
|
||||
static_assert(Seq::Size() == Mask::Size(), "wrong!");
|
||||
|
||||
return typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type{};
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <typename WorkSeq, typename RemainValues, typename RemainIds>
|
||||
struct modify_sequence_elements_by_ids_impl
|
||||
{
|
||||
using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
|
||||
|
||||
using type =
|
||||
typename modify_sequence_elements_by_ids_impl<new_work_seq,
|
||||
decltype(RemainValues::PopFront()),
|
||||
decltype(RemainIds::PopFront())>::type;
|
||||
};
|
||||
|
||||
template <typename WorkSeq>
|
||||
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
|
||||
{
|
||||
using type = WorkSeq;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename Seq, typename Values, typename Ids>
|
||||
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
|
||||
{
|
||||
static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");
|
||||
|
||||
return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
__host__ __device__ constexpr index_t
|
||||
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
|
||||
{
|
||||
index_t result = Init;
|
||||
|
||||
for(index_t i = 0; i < Seq::Size(); ++i)
|
||||
{
|
||||
result = f(result, Seq::At(i));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// TODO: a generic any_of for any container
|
||||
template <typename Seq, typename F>
|
||||
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
|
||||
{
|
||||
bool flag = false;
|
||||
|
||||
for(index_t i = 0; i < Seq::Size(); ++i)
|
||||
{
|
||||
flag = flag || f(Seq::At(i));
|
||||
}
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
// TODO: a generic all_of for any container
|
||||
template <typename Seq, typename F>
|
||||
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
|
||||
{
|
||||
bool flag = true;
|
||||
|
||||
for(index_t i = 0; i < Seq::Size(); ++i)
|
||||
{
|
||||
flag = flag && f(Seq::At(i));
|
||||
}
|
||||
|
||||
return flag;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
36
composable_kernel/include/utility/sequence_helper.hpp
Normal file
36
composable_kernel/include/utility/sequence_helper.hpp
Normal file
@@ -0,0 +1,36 @@
|
||||
#ifndef CK_SEQUENCE_HELPER_HPP
|
||||
#define CK_SEQUENCE_HELPER_HPP
|
||||
|
||||
#include "tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
|
||||
{
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
// F returns index_t
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
|
||||
{
|
||||
return typename sequence_gen<N, F>::type{};
|
||||
}
|
||||
|
||||
// F returns Number<>
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
|
||||
{
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
71
composable_kernel/include/utility/static_buffer.hpp
Normal file
71
composable_kernel/include/utility/static_buffer.hpp
Normal file
@@ -0,0 +1,71 @@
|
||||
#ifndef CK_STATIC_BUFFER_HPP
|
||||
#define CK_STATIC_BUFFER_HPP
|
||||
|
||||
#include "statically_indexed_array.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace,
|
||||
typename T,
|
||||
index_t N,
|
||||
bool InvalidElementUseNumericalZeroValue>
|
||||
struct StaticBuffer : public StaticallyIndexedArray<T, N>
|
||||
{
|
||||
using type = T;
|
||||
using base = StaticallyIndexedArray<T, N>;
|
||||
|
||||
T invalid_element_value_ = T{0};
|
||||
|
||||
__host__ __device__ constexpr StaticBuffer() : base{} {}
|
||||
|
||||
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
|
||||
: base{}, invalid_element_value_{invalid_element_value}
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
|
||||
{
|
||||
return BufferAddressSpace;
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return is_valid_element ? At(i) : T{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
return is_valid_element ? At(i) : invalid_element_value_;
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
|
||||
{
|
||||
if(is_valid_element)
|
||||
{
|
||||
At(i) = x;
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
|
||||
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
|
||||
};
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
|
||||
{
|
||||
return StaticBuffer<BufferAddressSpace, T, N, true>{};
|
||||
}
|
||||
|
||||
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
|
||||
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
|
||||
{
|
||||
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,40 @@
|
||||
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
|
||||
#define CK_STATICALLY_INDEXED_ARRAY_HPP
|
||||
|
||||
#include "functional2.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
__host__ __device__ constexpr auto generate_same_type_tuple()
|
||||
{
|
||||
return generate_tuple([](auto) -> T { return T{}; }, Number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
using same_type_tuple = decltype(generate_same_type_tuple<T, NSize>());
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename T, index_t NSize>
|
||||
using StaticallyIndexedArray = detail::same_type_tuple<T, NSize>;
|
||||
|
||||
template <typename X, typename... Xs>
|
||||
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
|
||||
{
|
||||
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
|
||||
}
|
||||
|
||||
// make empty StaticallyIndexedArray
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto make_statically_indexed_array()
|
||||
{
|
||||
return StaticallyIndexedArray<X, 0>();
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,108 @@
|
||||
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
|
||||
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t N>
|
||||
using MultiIndex = StaticallyIndexedArray<index_t, N>;
|
||||
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
|
||||
{
|
||||
return make_statically_indexed_array<index_t>(index_t{xs}...);
|
||||
}
|
||||
|
||||
template <index_t NSize>
|
||||
__host__ __device__ constexpr auto make_zero_multi_index()
|
||||
{
|
||||
return unpack([](auto... xs) { return make_multi_index(xs...); },
|
||||
typename uniform_sequence_gen<NSize, 0>::type{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto to_multi_index(const T& x)
|
||||
{
|
||||
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
|
||||
}
|
||||
|
||||
// Here should use MultiIndex<NSize>, instead of Tuple<Ys...>, although the former
|
||||
// is the alias of the latter. This is because compiler cannot infer the NSize if
|
||||
// using MultiIndex<NSize>
|
||||
// TODO: how to fix this?
|
||||
template <typename... Ys, typename X>
|
||||
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Ys, typename X>
|
||||
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename Y>
|
||||
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
Tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename Y>
|
||||
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
Tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename Y>
|
||||
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
Tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
// MultiIndex = index_t * MultiIndex
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
|
||||
{
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
Tuple<Xs...> r;
|
||||
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; });
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs>
|
||||
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
|
||||
{
|
||||
printf("{");
|
||||
printf("MultiIndex, ");
|
||||
printf("size %d,", index_t{sizeof...(Xs)});
|
||||
static_for<0, sizeof...(Xs), 1>{}(
|
||||
[&](auto i) { printf("%d ", static_cast<index_t>(x.At(i))); });
|
||||
printf("}");
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
21
composable_kernel/include/utility/synchronization.hpp
Normal file
21
composable_kernel/include/utility/synchronization.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
#ifndef CK_SYNCHRONIZATION_AMD_HPP
|
||||
#define CK_SYNCHRONIZATION_AMD_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ void block_sync_lds()
|
||||
{
|
||||
#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
166
composable_kernel/include/utility/tuple.hpp
Normal file
166
composable_kernel/include/utility/tuple.hpp
Normal file
@@ -0,0 +1,166 @@
|
||||
#ifndef CK_TUPLE_HPP
|
||||
#define CK_TUPLE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <index_t>
|
||||
struct TupleElementKey
|
||||
{
|
||||
__host__ __device__ constexpr TupleElementKey() = default;
|
||||
};
|
||||
|
||||
template <typename Key, typename Data>
|
||||
struct TupleElement
|
||||
{
|
||||
__host__ __device__ constexpr TupleElement() = default;
|
||||
|
||||
template <typename T,
|
||||
typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
|
||||
{
|
||||
}
|
||||
|
||||
Data mData;
|
||||
};
|
||||
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
|
||||
{
|
||||
return static_cast<const Data&>(x.mData);
|
||||
}
|
||||
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
|
||||
{
|
||||
return x.mData;
|
||||
}
|
||||
|
||||
// TODO: not sure the use of reference is correct
|
||||
template <typename Key, typename Data>
|
||||
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
|
||||
{
|
||||
return static_cast<Data&&>(x.mData);
|
||||
}
|
||||
|
||||
template <typename Indices, typename... Xs>
|
||||
struct TupleImpl;
|
||||
|
||||
template <index_t... Is, typename... Xs>
|
||||
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
|
||||
{
|
||||
__host__ __device__ constexpr TupleImpl() = default;
|
||||
|
||||
template <typename Y,
|
||||
typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
|
||||
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr TupleImpl(Y&& y)
|
||||
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
|
||||
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
|
||||
{
|
||||
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
|
||||
"wrong! inconsistent size");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
|
||||
{
|
||||
return get_tuple_element<TupleElementKey<I>>(*this);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
|
||||
{
|
||||
return get_tuple_element<TupleElementKey<I>>(*this);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename... Xs>
|
||||
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
|
||||
{
|
||||
using base =
|
||||
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
|
||||
|
||||
__host__ __device__ constexpr Tuple() = default;
|
||||
|
||||
template <typename Y,
|
||||
typename enable_if<sizeof...(Xs) == 1 &&
|
||||
!is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
|
||||
false>
|
||||
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& At(Number<I>) const
|
||||
{
|
||||
static_assert(I < base::Size(), "wrong! out of range");
|
||||
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& At(Number<I>)
|
||||
{
|
||||
static_assert(I < base::Size(), "wrong! out of range");
|
||||
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& operator()(Number<I> i)
|
||||
{
|
||||
return At(i);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto operator=(const T& a)
|
||||
{
|
||||
static_assert(T::Size() == Size(), "wrong! size not the same");
|
||||
|
||||
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
|
||||
};
|
||||
|
||||
template <typename... Xs>
|
||||
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
|
||||
{
|
||||
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
80
composable_kernel/include/utility/tuple_helper.hpp
Normal file
80
composable_kernel/include/utility/tuple_helper.hpp
Normal file
@@ -0,0 +1,80 @@
|
||||
#ifndef CK_TUPLE_HELPER_HPP
|
||||
#define CK_TUPLE_HELPER_HPP
|
||||
|
||||
#include "functional4.hpp"
|
||||
#include "tuple.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename... Ts>
|
||||
struct is_known_at_compile_time<Tuple<Ts...>>
|
||||
{
|
||||
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
|
||||
{
|
||||
return container_reduce(
|
||||
Tuple<Ts...>{},
|
||||
[](auto x, bool r) {
|
||||
return is_known_at_compile_time<
|
||||
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
|
||||
r;
|
||||
},
|
||||
true);
|
||||
}
|
||||
|
||||
static constexpr bool value = IsKnownAtCompileTime();
|
||||
};
|
||||
|
||||
template <typename F, index_t N>
|
||||
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
|
||||
{
|
||||
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
|
||||
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, index_t... Is>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
|
||||
{
|
||||
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y>
|
||||
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename X, typename Y, typename Z>
|
||||
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
||||
{
|
||||
return detail::transform_tuples_impl(
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
56
composable_kernel/include/utility/type.hpp
Normal file
56
composable_kernel/include/utility/type.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
#ifndef CK_TYPE_HPP
|
||||
#define CK_TYPE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct is_same : public integral_constant<bool, false>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename X>
|
||||
struct is_same<X, X> : public integral_constant<bool, true>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using remove_reference_t = typename std::remove_reference<T>::type;
|
||||
|
||||
template <typename T>
|
||||
using remove_cv_t = typename std::remove_cv<T>::type;
|
||||
|
||||
template <typename T>
|
||||
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
|
||||
|
||||
template <typename T>
|
||||
struct is_known_at_compile_time;
|
||||
|
||||
template <>
|
||||
struct is_known_at_compile_time<index_t>
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename T, T X>
|
||||
struct is_known_at_compile_time<integral_constant<T, X>>
|
||||
{
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
__host__ __device__ constexpr Y as_type(X x)
|
||||
{
|
||||
union AsType
|
||||
{
|
||||
X x;
|
||||
Y y;
|
||||
};
|
||||
|
||||
return AsType{x}.y;
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
14
composable_kernel/include/utility/utility.hpp
Normal file
14
composable_kernel/include/utility/utility.hpp
Normal file
@@ -0,0 +1,14 @@
|
||||
#ifndef CK_UTILITY_HPP
|
||||
#define CK_UTILITY_HPP
|
||||
|
||||
#include "config.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
|
||||
|
||||
__device__ index_t get_block_1d_id() { return blockIdx.x; }
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,370 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v1r2.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
|
||||
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
|
||||
constexpr index_t KPerThread = CK_PARAM_KPerThread;
|
||||
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
|
||||
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
|
||||
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
|
||||
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K_M0_M1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
|
||||
using ABlockTransferThreadClusterLengths_K_M0_M1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_M1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K_N0_N1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
|
||||
using BBlockTransferThreadClusterLengths_K_N0_N1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_N1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
|
||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k_m0_m1_grid_desc,
|
||||
void* p_b_k_n0_n1_grid_desc,
|
||||
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW));
|
||||
|
||||
const auto a_k_m_grid_desc = descs[I0];
|
||||
const auto b_k_n_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
||||
AKMGridDesc,
|
||||
BKNGridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
|
||||
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
|
||||
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
|
||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
};
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k_m0_m1_grid_desc,
|
||||
const void CONSTANT* p_b_k_n0_n1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1));
|
||||
|
||||
constexpr auto a_k_m_grid_desc = descs[I0];
|
||||
constexpr auto b_k_n_grid_desc = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AKMGridDesc = decltype(a_k_m_grid_desc);
|
||||
using BKNGridDesc = decltype(b_k_n_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
|
||||
AKMGridDesc,
|
||||
BKNGridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
M1PerThread,
|
||||
N1PerThread,
|
||||
KPerThread,
|
||||
M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11,
|
||||
ABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_M1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_N1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
constexpr auto a_k_m0_m1_grid_desc_tmp =
|
||||
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
|
||||
constexpr auto b_k_n0_n1_grid_desc_tmp =
|
||||
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
|
||||
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
|
||||
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
|
||||
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k_m0_m1_grid_desc =
|
||||
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
|
||||
const auto b_k_n0_n1_grid_desc =
|
||||
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
|
||||
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
|
||||
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
|
||||
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k_m0_m1_grid_desc,
|
||||
b_k_n0_n1_grid_desc,
|
||||
c_m0_m10_m11_n0_n10_n11_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
};
|
||||
@@ -0,0 +1,358 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
|
||||
int n,
|
||||
int c,
|
||||
int hi,
|
||||
int wi,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW),
|
||||
Number<K1>{});
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||
a_k0_m_k1_grid_desc;
|
||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||
b_k0_n_k1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||
c_m0_m1_m2_n_grid_desc;
|
||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using BGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
};
|
||||
@@ -0,0 +1,357 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
|
||||
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
|
||||
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
|
||||
|
||||
constexpr index_t MPerWave = CK_PARAM_MPerWave;
|
||||
constexpr index_t NPerWave = CK_PARAM_NPerWave;
|
||||
constexpr index_t MRepeat = CK_PARAM_MRepeat;
|
||||
constexpr index_t NRepeat = CK_PARAM_NRepeat;
|
||||
constexpr index_t K1 = CK_PARAM_K1;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M_K1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using BBlockTransferThreadSliceLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N_K1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
|
||||
|
||||
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
|
||||
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
|
||||
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
|
||||
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
|
||||
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
|
||||
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
|
||||
int n,
|
||||
int hi,
|
||||
int wi,
|
||||
int c,
|
||||
int k,
|
||||
int y,
|
||||
int x,
|
||||
int convStrideH,
|
||||
int convStrideW,
|
||||
int convDilationY,
|
||||
int convDilationX,
|
||||
int leftPadH,
|
||||
int leftPadW,
|
||||
int rightPadH,
|
||||
int rightPadW,
|
||||
void* p_a_k0_m_k1_grid_desc,
|
||||
void* p_b_k0_n_k1_grid_desc,
|
||||
void* p_c_m0_m1_m2_n_grid_desc,
|
||||
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
|
||||
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c));
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
|
||||
in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(convStrideH, convStrideW),
|
||||
make_tuple(convDilationY, convDilationX),
|
||||
make_tuple(leftPadH, leftPadW),
|
||||
make_tuple(rightPadH, rightPadW),
|
||||
Number<K1>{});
|
||||
|
||||
const auto a_k0_m_k1_grid_desc = descs[I0];
|
||||
const auto b_k0_n_k1_grid_desc = descs[I1];
|
||||
const auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
|
||||
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
|
||||
auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
if(hipThreadIdx_x == 0)
|
||||
{
|
||||
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
|
||||
a_k0_m_k1_grid_desc;
|
||||
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
|
||||
b_k0_n_k1_grid_desc;
|
||||
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
|
||||
c_m0_m1_m2_n_grid_desc;
|
||||
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
|
||||
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_a_k0_m_k1_grid_desc,
|
||||
const void CONSTANT* p_b_k0_n_k1_grid_desc,
|
||||
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
|
||||
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
|
||||
{
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr auto in_n_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||
constexpr auto wei_k_y_x_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256));
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
Number<K1>{});
|
||||
|
||||
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
|
||||
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
|
||||
constexpr auto c_m_n_grid_desc = descs[I2];
|
||||
|
||||
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
|
||||
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
|
||||
using CMNGridDesc = decltype(c_m_n_grid_desc);
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{})));
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
|
||||
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AK0MK1GridDesc,
|
||||
BK0NK1GridDesc,
|
||||
CMNGridDesc,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
false>;
|
||||
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
|
||||
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
|
||||
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
|
||||
|
||||
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
|
||||
using CBlockIdToM0N0BlockClusterAdaptor =
|
||||
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
|
||||
|
||||
const auto a_k0_m_k1_grid_desc =
|
||||
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
|
||||
const auto b_k0_n_k1_grid_desc =
|
||||
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
|
||||
const auto c_m0_m1_m2_n_grid_desc =
|
||||
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
|
||||
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
|
||||
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
|
||||
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseGemm::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_k0_m_k1_grid_desc,
|
||||
b_k0_n_k1_grid_desc,
|
||||
c_m0_m1_m2_n_grid_desc,
|
||||
c_blockid_to_m0_n0_block_cluster_adaptor);
|
||||
};
|
||||
@@ -0,0 +1,405 @@
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
using namespace ck;
|
||||
|
||||
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
|
||||
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
|
||||
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
|
||||
|
||||
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
|
||||
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
|
||||
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
|
||||
|
||||
constexpr index_t BlockSize = CK_PARAM_BlockSize;
|
||||
|
||||
constexpr auto GN0 = Number<CK_PARAM_GN0>{};
|
||||
constexpr auto GK1 = Number<CK_PARAM_GK1>{};
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
|
||||
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
|
||||
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
|
||||
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
|
||||
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>;
|
||||
using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>;
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
|
||||
Sequence<CK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
|
||||
using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>;
|
||||
using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>;
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
|
||||
Sequence<CK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
|
||||
using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
|
||||
|
||||
using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>;
|
||||
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
|
||||
|
||||
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
|
||||
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
|
||||
|
||||
extern "C" __global__ void
|
||||
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
|
||||
int C_,
|
||||
int Hi_,
|
||||
int Wi_,
|
||||
int K_,
|
||||
int Y_,
|
||||
int X_,
|
||||
int ConvStrideH_,
|
||||
int ConvStrideW_,
|
||||
int ConvDilationH_,
|
||||
int ConvDilationW_,
|
||||
int InLeftPadH_,
|
||||
int InLeftPadW_,
|
||||
int InRightPadH_,
|
||||
int InRightPadW_,
|
||||
void* p_desc_tuple)
|
||||
{
|
||||
index_t N = static_cast<index_t>(N_);
|
||||
index_t C = static_cast<index_t>(C_);
|
||||
index_t Hi = static_cast<index_t>(Hi_);
|
||||
index_t Wi = static_cast<index_t>(Wi_);
|
||||
index_t K = static_cast<index_t>(K_);
|
||||
index_t Y = static_cast<index_t>(Y_);
|
||||
index_t X = static_cast<index_t>(X_);
|
||||
index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
|
||||
index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
|
||||
index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
|
||||
index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
|
||||
index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
|
||||
index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
|
||||
index_t InRightPadH = static_cast<index_t>(InRightPadH_);
|
||||
index_t InRightPadW = static_cast<index_t>(InRightPadW_);
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
const index_t Ho =
|
||||
(Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
|
||||
const index_t Wo =
|
||||
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
|
||||
|
||||
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(ConvStrideH, ConvStrideW),
|
||||
make_tuple(ConvDilationH, ConvDilationW),
|
||||
make_tuple(InLeftPadH, InLeftPadW),
|
||||
make_tuple(InRightPadH, InRightPadW),
|
||||
GN0,
|
||||
GK1);
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowStepHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
auto desc_tuple =
|
||||
make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1),
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1),
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1),
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
|
||||
*static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
|
||||
}
|
||||
};
|
||||
|
||||
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void CONSTANT* p_desc_tuple)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
constexpr auto wei_k_c_y_x_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
|
||||
constexpr auto out_n_k_ho_wo_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
|
||||
|
||||
constexpr auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
GN0,
|
||||
GK1);
|
||||
|
||||
constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using AGridStepHacks =
|
||||
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using BGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
|
||||
|
||||
using CGridStepHacks = decltype(make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
|
||||
|
||||
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
|
||||
|
||||
using BGridMoveSliceWindowStepHacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
|
||||
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
|
||||
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1));
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
|
||||
decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
|
||||
b_grid_desc_gk0_gn0_gn1_gk1));
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
|
||||
decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1));
|
||||
|
||||
using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
|
||||
BGridDesc_GK0_GN0_GN10_GN11_GK1{},
|
||||
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
|
||||
CGridBlockCluster_BlockId_To_GM10_GN10{}));
|
||||
|
||||
const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
// TODO: how to cast?
|
||||
(const void*)p_desc_tuple
|
||||
#pragma clang diagnostic pop
|
||||
);
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2];
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];
|
||||
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
};
|
||||
125
external/rocm/include/bfloat16_dev.hpp
vendored
Normal file
125
external/rocm/include/bfloat16_dev.hpp
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
/*******************************************************************************
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
* Copyright (c) 2019 Advanced Micro Devices, Inc.
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
* of this software and associated documentation files (the "Software"), to deal
|
||||
* in the Software without restriction, including without limitation the rights
|
||||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
* copies of the Software, and to permit persons to whom the Software is
|
||||
* furnished to do so, subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be included in all
|
||||
* copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
* SOFTWARE.
|
||||
*
|
||||
*******************************************************************************/
|
||||
#ifndef BFLOAT16_DEVICE_HPP
|
||||
#define BFLOAT16_DEVICE_HPP
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
#define EXECUTION_SPECIFIER __device__
|
||||
#else
|
||||
#define EXECUTION_SPECIFIER
|
||||
#endif // MIOPEN_BACKEND_HIP
|
||||
|
||||
typedef union
|
||||
{
|
||||
uint u32;
|
||||
ushort2 ushortx2;
|
||||
|
||||
// Composable kernels are written in HIP language. The language doesnt support
|
||||
// ushort2.hi or ushort2.low.
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
ushort ushortvec[2];
|
||||
#endif // MIOPEN_BACKEND_HIP
|
||||
float f32;
|
||||
} cvt_bf16_fp32_t;
|
||||
|
||||
EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val)
|
||||
{
|
||||
cvt_bf16_fp32_t target_val;
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
target_val.ushortx2 = make_ushort2(0, src_val);
|
||||
#else
|
||||
target_val.ushortx2 = (ushort2)(0, src_val);
|
||||
#endif
|
||||
|
||||
return target_val.f32;
|
||||
}
|
||||
|
||||
EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val)
|
||||
{
|
||||
cvt_bf16_fp32_t target_val;
|
||||
target_val.f32 = src_val;
|
||||
// BF16 round and NaN preservation code matches
|
||||
// https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h
|
||||
if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN
|
||||
{
|
||||
// When all of the exponent bits are 1, the value is Inf or NaN.
|
||||
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
|
||||
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
|
||||
// bit being 1. Signaling NaN is indicated by the most significant
|
||||
// mantissa bit being 0 but some other bit(s) being 1. If any of the
|
||||
// lower 16 bits of the mantissa are 1, we set the least significant bit
|
||||
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
|
||||
// the bloat16's mantissa bits are all 0.
|
||||
if((target_val.u32 & 0xffff) != 0)
|
||||
{
|
||||
target_val.u32 |= 0x10000; // Preserve signaling NaN
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifdef MIOPEN_USE_RNE_BFLOAT16
|
||||
// When the exponent bits are not all 1s, then the value is zero, normal,
|
||||
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
|
||||
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
|
||||
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
|
||||
// least significant bits of the float mantissa are greater than 0x8000,
|
||||
// or if they are equal to 0x8000 and the least significant bit of the
|
||||
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
|
||||
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
|
||||
// has the value 0x7f, then incrementing it causes it to become 0x00 and
|
||||
// the exponent is incremented by one, which is the next higher FP value
|
||||
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
|
||||
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
|
||||
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
|
||||
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
|
||||
// incrementing it causes it to become an exponent of 0xFF and a mantissa
|
||||
// of 0x00, which is Inf, the next higher value to the unrounded value.
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1));
|
||||
#else
|
||||
target_val.u32 +=
|
||||
(0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even
|
||||
#endif // MIOPEN_BACKEND_HIP
|
||||
#endif // MIOPEN_USE_RNE_BFLOAT16
|
||||
}
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
return target_val.ushortvec[1];
|
||||
#else
|
||||
return target_val.ushortx2.hi;
|
||||
#endif // MIOPEN_BACKEND_HIP
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // BFLOAT16_DEVICE_HPP
|
||||
2
host/CMakeLists.txt
Normal file
2
host/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(host_tensor)
|
||||
add_subdirectory(driver_offline)
|
||||
21
host/driver_offline/CMakeLists.txt
Normal file
21
host/driver_offline/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
include_directories(BEFORE
|
||||
include
|
||||
${PROJECT_SOURCE_DIR}/host/host_tensor/include
|
||||
${PROJECT_SOURCE_DIR}/host/solver/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
|
||||
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
|
||||
${PROJECT_SOURCE_DIR}/external/rocm/include
|
||||
)
|
||||
|
||||
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
|
||||
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
|
||||
|
||||
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
|
||||
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
|
||||
|
||||
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
|
||||
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)
|
||||
@@ -0,0 +1,330 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmM,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
out_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,306 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
const Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
I0,
|
||||
I0,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto in_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
|
||||
|
||||
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(in_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<2, 0, 1>,
|
||||
Sequence<0, 2, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
#if 0
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
|
||||
#else
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
|
||||
#endif
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(in_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
true // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
out_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
in_gemmm_gemmn_grid_desc,
|
||||
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
in_m0_m1_m2_n_grid_step_hacks,
|
||||
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_dlops_v1r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
|
||||
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
|
||||
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
|
||||
const auto in_gemmk_gemmn_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_dlops_v1r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk_gemmm_grid_desc),
|
||||
decltype(in_gemmk_gemmn_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM1100,
|
||||
GemmM11N11ThreadClusterN1100,
|
||||
GemmM11N11ThreadClusterM1101,
|
||||
GemmM11N11ThreadClusterN1101,
|
||||
GemmABlockTransferThreadSliceLengths_K_M0_M1,
|
||||
GemmABlockTransferThreadClusterLengths_K_M0_M1,
|
||||
Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
|
||||
0, // ABlockTransferSrcVectorDim
|
||||
GemmABlockTransferSrcScalarPerVector_K,
|
||||
GemmABlockTransferDstScalarPerVector_M1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
|
||||
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
|
||||
Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
GemmBBlockTransferSrcScalarPerVector_N1,
|
||||
GemmBBlockTransferDstScalarPerVector_N1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk_gemmm_grid_desc,
|
||||
in_gemmk_gemmn_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,280 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 0
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4]
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmKPack = 4;
|
||||
|
||||
constexpr index_t MRepeat = 1;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
|
||||
wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
#if 0
|
||||
float ave_time = launch_kernel_gemm_xdlops_v1
|
||||
#else
|
||||
float ave_time = launch_kernel_gemm_xdlops_v2
|
||||
#endif
|
||||
<BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmKPack,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_KPack,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused
|
||||
// with MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_dlops_v1r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 1;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 2;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlockM1 = 128;
|
||||
constexpr index_t GemmNPerBlockN1 = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmM1PerThreadM111 = 4;
|
||||
constexpr index_t GemmN1PerThreadN111 = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
|
||||
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
|
||||
|
||||
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
|
||||
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
|
||||
|
||||
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
|
||||
|
||||
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_dlops_v1r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlockM1,
|
||||
GemmNPerBlockN1,
|
||||
GemmKPerBlock,
|
||||
GemmM1PerThreadM111,
|
||||
GemmN1PerThreadN111,
|
||||
GemmKPerThread,
|
||||
GemmM11N11ThreadClusterM110Xs,
|
||||
GemmM11N11ThreadClusterN110Xs,
|
||||
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
|
||||
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
|
||||
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
|
||||
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
|
||||
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
|
||||
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
|
||||
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
|
||||
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
|
||||
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_N11,
|
||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
|
||||
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
|
||||
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
|
||||
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<0, 2, 1>,
|
||||
Sequence<1, 0, 2>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1>,
|
||||
2,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
|
||||
in_n_hi_wi_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 1, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 2, 0, 0>{}));
|
||||
|
||||
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
6,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
wei_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Hi = in_n_hi_wi_c_lengths[I1];
|
||||
const auto Wi = in_n_hi_wi_c_lengths[I2];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "driver_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
|
||||
const InLengths& in_n_hi_wi_c_lengths,
|
||||
const WeiLengths& wei_k_y_x_c_lengths,
|
||||
const OutLengths& out_n_ho_wo_k_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_hi_wi_c,
|
||||
const Tensor<TInWei>& wei_k_y_x_c,
|
||||
Tensor<TOut>& out_n_ho_wo_k,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
|
||||
|
||||
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
|
||||
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
|
||||
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 256;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 4;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 8;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
|
||||
wei_k_y_x_c_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GemmK1>{});
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
|
||||
const auto out_gemmm_gemmn_grid_desc = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
|
||||
|
||||
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4+: M0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 5+: M1
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 6+: M2
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 7+: N1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 4-: M0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 5-: M1
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
|
||||
|
||||
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
|
||||
|
||||
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_gemm_xdlops_v2r3<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
|
||||
decltype(out_gemmm_gemmn_grid_desc),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmABlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
|
||||
Sequence<1, 0, 2>,
|
||||
Sequence<1, 0, 2>,
|
||||
2,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK1,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmK1,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
|
||||
7,
|
||||
GemmCThreadTransferDstScalarPerVector,
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
|
||||
decltype(out_m0_m1_m2_n_grid_step_hacks),
|
||||
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
|
||||
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
|
||||
false // CAccessOrderMRepeatNRepeat
|
||||
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
|
||||
out_m0_m1_m2_n_grid_step_hacks,
|
||||
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
{
|
||||
const auto N = out_n_ho_wo_k_lengths[I0];
|
||||
const auto K = out_n_ho_wo_k_lengths[I3];
|
||||
const auto C = wei_k_y_x_c_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_ho_wo_k_lengths[I1];
|
||||
const auto Wo = out_n_ho_wo_k_lengths[I2];
|
||||
|
||||
const auto Y = wei_k_y_x_c_lengths[I1];
|
||||
const auto X = wei_k_y_x_c_lengths[I2];
|
||||
|
||||
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t /* nrepeat */)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
const auto N = out_n_k_ho_wo_lengths[I0];
|
||||
const auto K = out_n_k_ho_wo_lengths[I1];
|
||||
const auto C = wei_k_c_y_x_lengths[I1];
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_lengths[I2];
|
||||
const auto Wi = in_n_c_hi_wi_lengths[I3];
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_lengths[I2];
|
||||
const auto Wo = out_n_k_ho_wo_lengths[I3];
|
||||
|
||||
const auto Y = wei_k_c_y_x_lengths[I2];
|
||||
const auto X = wei_k_c_y_x_lengths[I3];
|
||||
|
||||
const auto C0 = C / Number<InWeiVectorSize>{};
|
||||
const auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
const auto K0 = K / Number<InWeiVectorSize>{};
|
||||
const auto K1 = Number<InWeiVectorSize>{};
|
||||
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
|
||||
Tensor<TOut> out_n_k0_ho_wo_k1(
|
||||
HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));
|
||||
|
||||
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
|
||||
in_n_c_hi_wi(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
|
||||
wei_k_c_y_x(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
|
||||
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
|
||||
|
||||
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
|
||||
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
|
||||
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
|
||||
const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = 16;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#else
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 0
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
#else
|
||||
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W>{};
|
||||
|
||||
conv_driver.Run(wei_k_c0_y_x_desc,
|
||||
in_n_c0_hi_wi_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
|
||||
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) =
|
||||
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_contraction_dlops_v1r2.hpp"
|
||||
|
||||
template <typename TInWei,
|
||||
typename TAcc,
|
||||
typename TOut,
|
||||
typename InLengths,
|
||||
typename WeiLengths,
|
||||
typename OutLengths,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||
const InLengths& in_n_c_hi_wi_lengths,
|
||||
const WeiLengths& wei_k_c_y_x_lengths,
|
||||
const OutLengths& out_n_k_ho_wo_lengths,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
|
||||
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
|
||||
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
|
||||
|
||||
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
|
||||
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
|
||||
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
|
||||
|
||||
#if 1
|
||||
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 1;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#elif 1
|
||||
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
|
||||
// cdata = 64, BlockSize = 256
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GN0 = 4;
|
||||
constexpr index_t GK1 = 2;
|
||||
|
||||
constexpr index_t GM1PerBlockGM11 = 128;
|
||||
constexpr index_t GN1PerBlockGN11 = 32;
|
||||
constexpr index_t GK0PerBlock = 8;
|
||||
|
||||
constexpr index_t BM1PerThreadBM11 = 4;
|
||||
constexpr index_t BN1PerThreadBN11 = 4;
|
||||
constexpr index_t BK0PerThread = 1;
|
||||
|
||||
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
|
||||
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
|
||||
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
|
||||
|
||||
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
|
||||
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
|
||||
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
|
||||
|
||||
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
|
||||
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
|
||||
#endif
|
||||
|
||||
const auto descs =
|
||||
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
|
||||
in_desc_n_c_hi_wi,
|
||||
out_desc_n_k_ho_wo,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
Number<GN0>{},
|
||||
Number<GK1>{});
|
||||
|
||||
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
|
||||
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
|
||||
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
|
||||
|
||||
// HACK: hacks that control index calculation when iterating over A, B, C matrix
|
||||
constexpr auto wei_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||
|
||||
constexpr auto in_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
|
||||
|
||||
constexpr auto out_grid_step_hacks = make_tuple(
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
|
||||
make_tuple(
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
|
||||
|
||||
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
|
||||
|
||||
constexpr auto in_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = driver_contraction_dlops_v1r2<
|
||||
BlockSize,
|
||||
TInWei,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
|
||||
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
|
||||
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
|
||||
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
|
||||
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
CThreadTransferDstScalarPerVector_BN1,
|
||||
decltype(wei_grid_step_hacks),
|
||||
decltype(in_grid_step_hacks),
|
||||
decltype(out_grid_step_hacks),
|
||||
decltype(wei_grid_move_slice_window_step_hacks),
|
||||
decltype(in_grid_move_slice_window_step_hacks)>(
|
||||
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
wei_grid_desc_gk0_gm0_gm1_gk1,
|
||||
in_grid_desc_gk0_gn0_gn1_gk1,
|
||||
out_grid_desc_gm0_gm1_gn0_gn1,
|
||||
wei_grid_step_hacks,
|
||||
in_grid_step_hacks,
|
||||
out_grid_step_hacks,
|
||||
wei_grid_move_slice_window_step_hacks,
|
||||
in_grid_move_slice_window_step_hacks,
|
||||
nrepeat);
|
||||
|
||||
float perf = static_cast<float>(calculate_convolution_flops(
|
||||
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
286
host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
Normal file
286
host/driver_offline/include/driver_contraction_dlops_v1r2.hpp
Normal file
@@ -0,0 +1,286 @@
|
||||
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_contraction_dlops_v1r2.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
|
||||
typename AGridDesc_GK0_GM0_GM1_GK1,
|
||||
typename BGridDesc_GK0_GN0_GN1_GK1,
|
||||
typename CGridDesc_GM0_GM1_GN0_GN1,
|
||||
ck::index_t GM1PerBlockGM11,
|
||||
ck::index_t GN1PerBlockGN11,
|
||||
ck::index_t GK0PerBlock,
|
||||
ck::index_t BM1PerThreadBM11,
|
||||
ck::index_t BN1PerThreadBN11,
|
||||
ck::index_t BK0PerThread,
|
||||
typename BM10BN10ThreadClusterBM10Xs,
|
||||
typename BM10BN10ThreadClusterBN10Xs,
|
||||
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
typename AGridStepHacks,
|
||||
typename BGridStepHacks,
|
||||
typename CGridStepHacks,
|
||||
typename AGridMoveSliceWindowStepHacks,
|
||||
typename BGridMoveSliceWindowStepHacks>
|
||||
__host__ float
|
||||
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
|
||||
const FloatAB* p_b_grid,
|
||||
FloatC* p_c_grid,
|
||||
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
|
||||
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
|
||||
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
ck::index_t nrepeat)
|
||||
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
// GEMM
|
||||
using GridwiseContraction =
|
||||
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
CGlobalMemoryDataOperation,
|
||||
AGridDesc_GK0_GM0_GM1_GK1,
|
||||
BGridDesc_GK0_GN0_GN1_GK1,
|
||||
CGridDesc_GM0_GM1_GN0_GN1,
|
||||
GM1PerBlockGM11,
|
||||
GN1PerBlockGN11,
|
||||
GK0PerBlock,
|
||||
BM1PerThreadBM11,
|
||||
BN1PerThreadBN11,
|
||||
BK0PerThread,
|
||||
BM10BN10ThreadClusterBM10Xs,
|
||||
BM10BN10ThreadClusterBN10Xs,
|
||||
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
|
||||
ABlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
BBlockTransferSrcVectorTensorContiguousDimOrder,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks>;
|
||||
|
||||
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
|
||||
|
||||
if(!GridwiseContraction::CheckValidity(
|
||||
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
|
||||
{
|
||||
throw std::runtime_error("wrong! "
|
||||
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
|
||||
"GM0_GM1_GN0_GN1 has invalid setting");
|
||||
}
|
||||
|
||||
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
|
||||
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
|
||||
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
|
||||
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
|
||||
|
||||
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
|
||||
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
|
||||
|
||||
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
|
||||
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
|
||||
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
|
||||
|
||||
// c_grid_block_cluster_blockid_to_gm10_gn10
|
||||
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
|
||||
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
|
||||
c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
using CGridBlockCluster_BlockId_To_GM10_GN10 =
|
||||
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
|
||||
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
|
||||
|
||||
const bool has_double_tail_k_block_loop =
|
||||
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
|
||||
|
||||
{
|
||||
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
|
||||
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
|
||||
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
|
||||
|
||||
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
|
||||
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
true,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_contraction_dlops_v1r2<
|
||||
GridwiseContraction,
|
||||
FloatAB,
|
||||
FloatC,
|
||||
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
|
||||
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
|
||||
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
|
||||
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
|
||||
false,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
c_grid_block_cluster_blockid_to_gm10_gn10);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
#endif
|
||||
@@ -0,0 +1,349 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t EPerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_W,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0];
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
#if 1
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_step_hacks),
|
||||
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf =
|
||||
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@@ -0,0 +1,364 @@
|
||||
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_dlops_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t HoPerBlock,
|
||||
ck::index_t WoPerBlock,
|
||||
ck::index_t EPerBlock,
|
||||
ck::index_t KPerThread,
|
||||
ck::index_t HoPerThread,
|
||||
ck::index_t WoPerThread,
|
||||
ck::index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector_E,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K,
|
||||
ck::index_t BThreadTransferSrcScalarPerVector_W,
|
||||
ck::index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
|
||||
<< std::endl;
|
||||
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
|
||||
<< std::endl;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, 0, OutRightPadH),
|
||||
make_pad_transform(Wo, 0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_step_hacks),
|
||||
decltype(b_e_n_ho_wo_global_step_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_step_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf =
|
||||
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc)) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
#endif
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user