Composable kernel init integration v3 (#1097)

* Squashed 'src/composable_kernel/' content from commit f6edda611

git-subtree-dir: src/composable_kernel
git-subtree-split: f6edda6119

* add solver ConvIgemmFwdV6r1DlopsNchwKcyxNkhw; rename static ck source files

* Squashed 'src/composable_kernel/' changes from f6edda611..5781adf5c

5781adf5c Update develop (#5) (#6)
97e6d514f Merge pull request #4 from ROCmSoftwarePlatform/separate_online_compile
7b1ec41e5 refactor
49c33aaea refactor
54b3e73d1 rename

git-subtree-dir: src/composable_kernel
git-subtree-split: 5781adf5cf

* fix

* refactor

* remove online compilation from CK

* refactor

* fix

* add ctest

* add c-style pointer cast

* vector/scalar pointer cast use c-style pointer cast instead of reinterpret_cast

* fix clang warning suppression

* tidy

* suppress cppcheck

* fix enum issue

* revert chagnes to hip build

* fix kernel filename

* update CK build script

* rename

* rename

* make innner product compatiable on gfx900

* Update src/include/miopen/solver/ck_utility_common.hpp

Co-authored-by: JD <Jehandad.Khan@amd.com>

* compiler parameter use stream

* use int instead of index_t in kernel wrapper

* DynamicBuffer, StaticBuffer, amd_buffer_load support customized value for invalid element

* refactor

* refactor

* change cmakelist

* change ck common utility

* fix

Co-authored-by: JD <Jehandad.Khan@amd.com>

[ROCm/composable_kernel commit: 6fe3627a9e]
This commit is contained in:
Chao Liu
2021-08-19 10:55:03 -05:00
committed by GitHub
commit ee428d2d6f
126 changed files with 30654 additions and 0 deletions

90
.clang-format Normal file
View File

@@ -0,0 +1,90 @@
---
Language: Cpp
AccessModifierOffset: 0
AlignAfterOpenBracket: Align
AlignConsecutiveAssignments: true
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: false
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: true
AfterControlStatement: true
AfterEnum: true
AfterFunction: true
AfterNamespace: false
AfterObjCDeclaration: true
AfterStruct: true
AfterUnion: true
BeforeCatch: true
BeforeElse: true
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
ReflowComments: true
SortIncludes: false
SpaceAfterCStyleCast: false
# SpaceAfterTemplateKeyword: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: Never
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...

3
.clang-tidy Normal file
View File

@@ -0,0 +1,3 @@
CheckOptions:
- key: bugprone-reserved-identifier.AllowedIdentifiers
value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__'

198
CMakeLists.txt Normal file
View File

@@ -0,0 +1,198 @@
cmake_minimum_required(VERSION 3.5)
project(composable_kernel)
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
include(CheckCXXCompilerFlag)
## C++
enable_language(CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# workaround issue hipcc in rocm3.5 cannot find openmp
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
else()
find_package(OpenMP REQUIRED)
endif()
message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY})
## HIP
find_package(HIP REQUIRED)
message(STATUS "Build with HIP ${hip_VERSION}")
## half
#find_path(HALF_INCLUDE_DIR half.hpp)
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
# CMAKE_CXX_FLAGS
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
## tidy
include(EnableCompilerWarnings)
set(MIOPEN_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name)
if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+")
set(MIOPEN_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter)
# Enable tidy on hip
elseif(MIOPEN_BACKEND STREQUAL "HIP" OR MIOPEN_BACKEND STREQUAL "HIPNOGPU")
set(MIOPEN_TIDY_ERRORS ALL)
endif()
include(ClangTidy)
enable_clang_tidy(
CHECKS
*
-abseil-*
-android-cloexec-fopen
# Yea we shouldn't be using rand()
-cert-msc30-c
-bugprone-exception-escape
-bugprone-macro-parentheses
-cert-env33-c
-cert-msc32-c
-cert-msc50-cpp
-cert-msc51-cpp
-cert-dcl37-c
-cert-dcl51-cpp
-clang-analyzer-alpha.core.CastToStruct
-clang-analyzer-optin.performance.Padding
-clang-diagnostic-deprecated-declarations
-clang-diagnostic-extern-c-compat
-clang-diagnostic-unused-command-line-argument
-cppcoreguidelines-avoid-c-arrays
-cppcoreguidelines-avoid-magic-numbers
-cppcoreguidelines-explicit-virtual-functions
-cppcoreguidelines-init-variables
-cppcoreguidelines-macro-usage
-cppcoreguidelines-non-private-member-variables-in-classes
-cppcoreguidelines-pro-bounds-array-to-pointer-decay
-cppcoreguidelines-pro-bounds-constant-array-index
-cppcoreguidelines-pro-bounds-pointer-arithmetic
-cppcoreguidelines-pro-type-member-init
-cppcoreguidelines-pro-type-reinterpret-cast
-cppcoreguidelines-pro-type-union-access
-cppcoreguidelines-pro-type-vararg
-cppcoreguidelines-special-member-functions
-fuchsia-*
-google-explicit-constructor
-google-readability-braces-around-statements
-google-readability-todo
-google-runtime-int
-google-runtime-references
-hicpp-vararg
-hicpp-braces-around-statements
-hicpp-explicit-conversions
-hicpp-named-parameter
-hicpp-no-array-decay
# We really shouldn't use bitwise operators with signed integers, but
# opencl leaves us no choice
-hicpp-avoid-c-arrays
-hicpp-signed-bitwise
-hicpp-special-member-functions
-hicpp-uppercase-literal-suffix
-hicpp-use-auto
-hicpp-use-equals-default
-hicpp-use-override
-llvm-header-guard
-llvm-include-order
#-llvmlibc-*
-llvmlibc-restrict-system-libc-headers
-llvmlibc-callee-namespace
-llvmlibc-implementation-in-namespace
-llvm-else-after-return
-llvm-qualified-auto
-misc-misplaced-const
-misc-non-private-member-variables-in-classes
-misc-no-recursion
-modernize-avoid-bind
-modernize-avoid-c-arrays
-modernize-pass-by-value
-modernize-use-auto
-modernize-use-default-member-init
-modernize-use-equals-default
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-performance-unnecessary-value-param
-readability-braces-around-statements
-readability-else-after-return
# we are not ready to use it, but very useful
-readability-function-cognitive-complexity
-readability-isolate-declaration
-readability-magic-numbers
-readability-named-parameter
-readability-uppercase-literal-suffix
-readability-convert-member-functions-to-static
-readability-qualified-auto
-readability-redundant-string-init
# too many narrowing conversions in our code
-bugprone-narrowing-conversions
-cppcoreguidelines-narrowing-conversions
-altera-struct-pack-align
-cppcoreguidelines-prefer-member-initializer
${MIOPEN_TIDY_CHECKS}
${MIOPEN_TIDY_ERRORS}
HEADER_FILTER
"\.hpp$"
EXTRA_ARGS
-DMIOPEN_USE_CLANG_TIDY
)
include(CppCheck)
enable_cppcheck(
CHECKS
warning
style
performance
portability
SUPPRESS
ConfigurationNotChecked
constStatement
duplicateCondition
noExplicitConstructor
passedByValue
preprocessorErrorDirective
shadowVariable
unusedFunction
unusedPrivateFunction
unusedStructMember
unmatchedSuppression
FORCE
SOURCES
host/host_tensor/src
host/driver_offline/src
composable_kernel/src/kernel_wrapper
INCLUDE
host/host_tensor/include
host/solver/include
host/driver_offline/include
composable_kernel/include/*
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_BINARY_DIR}/include
DEFINE
CPPCHECK=1
__linux__=1
)
add_subdirectory(host)

177
README.md Normal file
View File

@@ -0,0 +1,177 @@
# How to build and run
# Docker
```
docker run \
-it \
--rm \
--privileged \
--group-add sudo \
-w /root/workspace \
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
rocm/tensorflow:rocm4.2-tf2.4-dev \
/bin/bash
```
# Install Boost for online compilation
https://www.boost.org/doc/libs/1_66_0/more/getting_started/unix-variants.html#easy-build-and-install
# Build
Add path of Boost
```
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
```
```
mkdir build && cd build
```
cmake cmd. Need to Specify target ID, example below is gfx908
```
cmake \
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 -O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \
-D HIP_ONLINE_COMPILER_FLAGS="-DCK_AMD_GPU_GFX908" \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
..
```
Build drivers: \
``conv_fwd_driver_offline`` is (offline compilation) driver for forward convolution, \
``conv_bwd_driver_offline`` is (offline compilation) driver for backward-data convolution \
``conv_fwd_driver_online`` is (online compilation) driver for forward convolution
```
make -j conv_fwd_driver_offline
make -j conv_bwd_driver_offline
make -j conv_fwd_driver_online
```
# Run
* layout: 0 = NCHW; 1 = NHWC
* algo: algorithm
* verify: 0 = no verification; 1 = do verification
* init: 0 ~ 5. initialization method
* log: 0 = no log; 1 = do log
* repeat: number of time kernel being launched
```
######################################################## layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
./host/driver_offline/conv_bwd_driver_offline 1 5 0 0 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
```
# Result
Forward convoltuion, FP16, NCHW
```
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
layout: 0
in: dim 4, lengths {128, 192, 71, 71}, strides {967872, 5041, 71, 1}
wei: dim 4, lengths {256, 192, 3, 3}, strides {1728, 9, 3, 1}
out: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1296, 36, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{216, 256, 8}
b_k0_n_k1_grid_desc{216, 165888, 8}
c_m_n_grid_desc{ 256, 165888}
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 1.4155 ms, 103.686 TFlop/s
```
Forward convoltuion, FP16, NCHW
```
./host/driver_offline/conv_fwd_driver_offline 0 4 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
layout: 0
in: dim 4, lengths {256, 256, 14, 14}, strides {50176, 196, 14, 1}
wei: dim 4, lengths {1024, 256, 3, 3}, strides {2304, 9, 3, 1}
out: dim 4, lengths {256, 1024, 14, 14}, strides {200704, 196, 14, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw
a_k0_m_k1_grid_desc{288, 1024, 8}
b_k0_n_k1_grid_desc{288, 50176, 8}
c_m_n_grid_desc{ 1024, 50176}
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 2.21357 ms, 106.959 TFlop/s
```
Forward convolution, FP16, NHWC
```
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
layout: 1
in: dim 4, lengths {128, 71, 71, 192}, strides {967872, 13632, 192, 1}
wei: dim 4, lengths {256, 3, 3, 192}, strides {1728, 576, 192, 1}
out: dim 4, lengths {128, 36, 36, 256}, strides {331776, 9216, 256, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {2, 2, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{216, 165888, 8}
b_k0_n_k1_grid_desc{216, 256, 8}
c_m_n_grid_desc{ 165888, 256}
launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 1.12014 ms, 131.025 TFlop/s
```
Forward convolution, FP16, NHWC
```
./host/driver_offline/conv_fwd_driver_offline 1 5 0 0 0 1 256 1024 256 3 3 14 14 1 1 1 1 1 1 1 1
layout: 1
in: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
wei: dim 4, lengths {1024, 3, 3, 256}, strides {2304, 768, 256, 1}
out: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 1.86877 ms, 126.693 TFlop/s
```
Backward data convolution, FP16, NHWC
```
./host/driver_offline/conv_bwd_driver_offline 1 1 0 3 0 1 256 256 1024 3 3 14 14 1 1 1 1 1 1 1 1
layout: 1
in: dim 4, lengths {256, 14, 14, 1024}, strides {200704, 14336, 1024, 1}
wei: dim 4, lengths {256, 3, 3, 1024}, strides {9216, 3072, 1024, 1}
out: dim 4, lengths {256, 14, 14, 256}, strides {50176, 3584, 256, 1}
InLeftPads size 2, {1, 1, }
InRightPads size 2, {1, 1, }
ConvStrides size 2, {1, 1, }
ConvDilations size 2, {1, 1, }
device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk
a_k0_m_k1_grid_desc{288, 50176, 8}
b_k0_n_k1_grid_desc{288, 1024, 8}
c_m_n_grid_desc{ 50176, 1024}
launch_and_time_kernel: grid_dim {1568, 1, 1}, block_dim {256, 1, 1}
Warm up
Start running 1 times...
Average time : 2.22461 ms, 106.428 TFlop/s
```

34
cmake/Analyzers.cmake Normal file
View File

@@ -0,0 +1,34 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
if(NOT TARGET analyze)
add_custom_target(analyze)
endif()
function(mark_as_analyzer)
add_dependencies(analyze ${ARGN})
endfunction()

162
cmake/ClangTidy.cmake Normal file
View File

@@ -0,0 +1,162 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
include(CMakeParseArguments)
include(Analyzers)
get_filename_component(CLANG_TIDY_EXE_HINT "${CMAKE_CXX_COMPILER}" PATH)
find_program(CLANG_TIDY_EXE
NAMES
clang-tidy
clang-tidy-5.0
clang-tidy-4.0
clang-tidy-3.9
clang-tidy-3.8
clang-tidy-3.7
clang-tidy-3.6
clang-tidy-3.5
HINTS
${CLANG_TIDY_EXE_HINT}
PATH_SUFFIXES
compiler/bin
PATHS
/opt/rocm/llvm/bin
/opt/rocm/hcc
/usr/local/opt/llvm/bin
)
function(find_clang_tidy_version VAR)
execute_process(COMMAND ${CLANG_TIDY_EXE} -version OUTPUT_VARIABLE VERSION_OUTPUT)
separate_arguments(VERSION_OUTPUT_LIST UNIX_COMMAND "${VERSION_OUTPUT}")
list(FIND VERSION_OUTPUT_LIST "version" VERSION_INDEX)
if(VERSION_INDEX GREATER 0)
math(EXPR VERSION_INDEX "${VERSION_INDEX} + 1")
list(GET VERSION_OUTPUT_LIST ${VERSION_INDEX} VERSION)
set(${VAR} ${VERSION} PARENT_SCOPE)
else()
set(${VAR} "0.0" PARENT_SCOPE)
endif()
endfunction()
if( NOT CLANG_TIDY_EXE )
message( STATUS "Clang tidy not found" )
set(CLANG_TIDY_VERSION "0.0")
else()
find_clang_tidy_version(CLANG_TIDY_VERSION)
message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}")
endif()
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits)
file(MAKE_DIRECTORY ${CLANG_TIDY_FIXIT_DIR})
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CLANG_TIDY_FIXIT_DIR})
macro(enable_clang_tidy)
set(options ANALYZE_TEMPORARY_DTORS ALL)
set(oneValueArgs HEADER_FILTER)
set(multiValueArgs CHECKS ERRORS EXTRA_ARGS)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}")
string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}")
set(CLANG_TIDY_EXTRA_ARGS)
foreach(ARG ${PARSE_EXTRA_ARGS})
list(APPEND CLANG_TIDY_EXTRA_ARGS "-extra-arg=${ARG}")
endforeach()
set(CLANG_TIDY_ALL)
if(PARSE_ALL)
set(CLANG_TIDY_ALL ALL)
endif()
message(STATUS "Clang tidy checks: ${CLANG_TIDY_CHECKS}")
if (${PARSE_ANALYZE_TEMPORARY_DTORS})
set(CLANG_TIDY_ANALYZE_TEMPORARY_DTORS "-analyze-temporary-dtors")
endif()
if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
set(CLANG_TIDY_ERRORS_ARG "")
else()
set(CLANG_TIDY_ERRORS_ARG "-warnings-as-errors='${CLANG_TIDY_ERRORS}'")
endif()
if (${CLANG_TIDY_VERSION} VERSION_LESS "3.9.0")
set(CLANG_TIDY_QUIET_ARG "")
else()
set(CLANG_TIDY_QUIET_ARG "-quiet")
endif()
if(PARSE_HEADER_FILTER)
string(REPLACE "$" "$$" CLANG_TIDY_HEADER_FILTER "${PARSE_HEADER_FILTER}")
else()
set(CLANG_TIDY_HEADER_FILTER ".*")
endif()
set(CLANG_TIDY_COMMAND
${CLANG_TIDY_EXE}
${CLANG_TIDY_QUIET_ARG}
-p ${CMAKE_BINARY_DIR}
-checks='${CLANG_TIDY_CHECKS}'
${CLANG_TIDY_ERRORS_ARG}
${CLANG_TIDY_EXTRA_ARGS}
${CLANG_TIDY_ANALYZE_TEMPORARY_DTORS}
-header-filter='${CLANG_TIDY_HEADER_FILTER}'
)
add_custom_target(tidy ${CLANG_TIDY_ALL})
mark_as_analyzer(tidy)
add_custom_target(tidy-base)
add_custom_target(tidy-make-fixit-dir COMMAND ${CMAKE_COMMAND} -E make_directory ${CLANG_TIDY_FIXIT_DIR})
add_custom_target(tidy-rm-fixit-dir COMMAND ${CMAKE_COMMAND} -E remove_directory ${CLANG_TIDY_FIXIT_DIR})
add_dependencies(tidy-make-fixit-dir tidy-rm-fixit-dir)
add_dependencies(tidy-base tidy-make-fixit-dir)
endmacro()
function(clang_tidy_check TARGET)
get_target_property(SOURCES ${TARGET} SOURCES)
# TODO: Use generator expressions instead
# COMMAND ${CLANG_TIDY_COMMAND} $<TARGET_PROPERTY:${TARGET},SOURCES>
# COMMAND ${CLANG_TIDY_COMMAND} $<JOIN:$<TARGET_PROPERTY:${TARGET},SOURCES>, >
foreach(SOURCE ${SOURCES})
if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS"))
string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file)
set(tidy_target tidy-target-${TARGET}-${tidy_file})
add_custom_target(${tidy_target}
# for some targets clang-tidy not able to get information from .clang-tidy
DEPENDS ${SOURCE}
COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml"
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..."
)
add_dependencies(${tidy_target} ${TARGET})
add_dependencies(${tidy_target} tidy-base)
add_dependencies(tidy ${tidy_target})
endif()
endforeach()
endfunction()

130
cmake/CppCheck.cmake Normal file
View File

@@ -0,0 +1,130 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
include(CMakeParseArguments)
include(ProcessorCount)
include(Analyzers)
find_program(CPPCHECK_EXE
NAMES
cppcheck
PATHS
/opt/rocm/bin
)
ProcessorCount(CPPCHECK_JOBS)
set(CPPCHECK_BUILD_DIR ${CMAKE_BINARY_DIR}/cppcheck-build)
file(MAKE_DIRECTORY ${CPPCHECK_BUILD_DIR})
set_property(DIRECTORY APPEND PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CPPCHECK_BUILD_DIR})
macro(enable_cppcheck)
set(options FORCE)
set(oneValueArgs)
set(multiValueArgs CHECKS SUPPRESS DEFINE UNDEFINE INCLUDE SOURCES)
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE ";" "," CPPCHECK_CHECKS "${PARSE_CHECKS}")
string(REPLACE ";" "\n" CPPCHECK_SUPPRESS "${PARSE_SUPPRESS};*:/usr/*")
file(WRITE ${CMAKE_BINARY_DIR}/cppcheck-supressions "${CPPCHECK_SUPPRESS}")
set(CPPCHECK_DEFINES)
foreach(DEF ${PARSE_DEFINE})
set(CPPCHECK_DEFINES "${CPPCHECK_DEFINES} -D${DEF}")
endforeach()
set(CPPCHECK_UNDEFINES)
foreach(DEF ${PARSE_UNDEFINE})
set(CPPCHECK_UNDEFINES "${CPPCHECK_UNDEFINES} -U${DEF}")
endforeach()
set(CPPCHECK_INCLUDES)
foreach(INC ${PARSE_INCLUDE})
set(CPPCHECK_INCLUDES "${CPPCHECK_INCLUDES} -I${INC}")
endforeach()
# set(CPPCHECK_FORCE)
set(CPPCHECK_FORCE "--project=${CMAKE_BINARY_DIR}/compile_commands.json")
if(PARSE_FORCE)
set(CPPCHECK_FORCE --force)
endif()
set(SOURCES)
set(GLOBS)
foreach(SOURCE ${PARSE_SOURCES})
get_filename_component(ABS_SOURCE ${SOURCE} ABSOLUTE)
if(EXISTS ${ABS_SOURCE})
if(IS_DIRECTORY ${ABS_SOURCE})
set(GLOBS "${GLOBS} ${ABS_SOURCE}/*.cpp ${ABS_SOURCE}/*.hpp ${ABS_SOURCE}/*.cxx ${ABS_SOURCE}/*.c ${ABS_SOURCE}/*.h")
else()
set(SOURCES "${SOURCES} ${ABS_SOURCE}")
endif()
else()
set(GLOBS "${GLOBS} ${ABS_SOURCE}")
endif()
endforeach()
file(WRITE ${CMAKE_BINARY_DIR}/cppcheck.cmake "
file(GLOB_RECURSE GSRCS ${GLOBS})
set(CPPCHECK_COMMAND
${CPPCHECK_EXE}
-q
# -v
# --report-progress
${CPPCHECK_FORCE}
--cppcheck-build-dir=${CPPCHECK_BUILD_DIR}
--platform=native
--template=gcc
--error-exitcode=1
-j ${CPPCHECK_JOBS}
${CPPCHECK_DEFINES}
${CPPCHECK_UNDEFINES}
${CPPCHECK_INCLUDES}
--enable=${CPPCHECK_CHECKS}
--inline-suppr
--suppressions-list=${CMAKE_BINARY_DIR}/cppcheck-supressions
${SOURCES} \${GSRCS}
)
string(REPLACE \";\" \" \" CPPCHECK_SHOW_COMMAND \"\${CPPCHECK_COMMAND}\")
message(\"\${CPPCHECK_SHOW_COMMAND}\")
execute_process(
COMMAND \${CPPCHECK_COMMAND}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE RESULT
)
if(NOT RESULT EQUAL 0)
message(FATAL_ERROR \"Cppcheck failed\")
endif()
")
add_custom_target(cppcheck
COMMAND ${CMAKE_COMMAND} -P ${CMAKE_BINARY_DIR}/cppcheck.cmake
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "cppcheck: Running cppcheck..."
)
mark_as_analyzer(cppcheck)
endmacro()

355
cmake/DoxygenDoc.cmake Normal file
View File

@@ -0,0 +1,355 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
include(CMakeParseArguments)
include(MainDoc)
find_program(DOXYGEN_EXECUTABLE NAMES doxygen
PATH_SUFFIXES bin
DOC "Doxygen documentation generator"
)
mark_as_advanced(DOXYGEN_EXECUTABLE)
find_path(DOT_EXECUTABLE NAMES dot
PATH_SUFFIXES bin
DOC "Graphviz"
)
mark_as_advanced(DOT_EXECUTABLE)
set(DOXYGEN_ARGS
ABBREVIATE_BRIEF
ALIASES
ALLEXTERNALS
ALLOW_UNICODE_NAMES
ALPHABETICAL_INDEX
ALWAYS_DETAILED_SEC
AUTOLINK_SUPPORT
BINARY_TOC
BRIEF_MEMBER_DESC
BUILTIN_STL_SUPPORT
CALLER_GRAPH
CALL_GRAPH
CASE_SENSE_NAMES
CHM_FILE
CHM_INDEX_ENCODING
CITE_BIB_FILES
CLANG_ASSISTED_PARSING
CLANG_OPTIONS
CLASS_DIAGRAMS
CLASS_GRAPH
COLLABORATION_GRAPH
COLS_IN_ALPHA_INDEX
COMPACT_LATEX
COMPACT_RTF
CPP_CLI_SUPPORT
CREATE_SUBDIRS
DIAFILE_DIRS
DIA_PATH
DIRECTORY_GRAPH
DISABLE_INDEX
DISTRIBUTE_GROUP_DOC
DOCBOOK_OUTPUT
DOCBOOK_PROGRAMLISTING
DOCSET_BUNDLE_ID
DOCSET_FEEDNAME
DOCSET_PUBLISHER_ID
DOCSET_PUBLISHER_NAME
DOTFILE_DIRS
DOT_CLEANUP
DOT_FONTNAME
DOT_FONTPATH
DOT_FONTSIZE
DOT_GRAPH_MAX_NODES
DOT_IMAGE_FORMAT
DOT_MULTI_TARGETS
DOT_NUM_THREADS
# DOT_PATH
DOT_TRANSPARENT
DOXYFILE_ENCODING
ECLIPSE_DOC_ID
ENABLED_SECTIONS
ENABLE_PREPROCESSING
ENUM_VALUES_PER_LINE
EXAMPLE_PATH
EXAMPLE_PATTERNS
EXAMPLE_RECURSIVE
EXCLUDE
EXCLUDE_PATTERNS
EXCLUDE_SYMBOLS
EXCLUDE_SYMLINKS
EXPAND_AS_DEFINED
EXPAND_ONLY_PREDEF
EXTENSION_MAPPING
EXTERNAL_GROUPS
EXTERNAL_PAGES
EXTERNAL_SEARCH
EXTERNAL_SEARCH_ID
EXTRACT_ALL
EXTRACT_ANON_NSPACES
EXTRACT_LOCAL_CLASSES
EXTRACT_LOCAL_METHODS
EXTRACT_PACKAGE
EXTRACT_PRIVATE
EXTRACT_STATIC
EXTRA_PACKAGES
EXTRA_SEARCH_MAPPINGS
EXT_LINKS_IN_WINDOW
FILE_PATTERNS
FILE_VERSION_FILTER
FILTER_PATTERNS
FILTER_SOURCE_FILES
FILTER_SOURCE_PATTERNS
FORCE_LOCAL_INCLUDES
FORMULA_FONTSIZE
FORMULA_TRANSPARENT
FULL_PATH_NAMES
GENERATE_AUTOGEN_DEF
GENERATE_BUGLIST
GENERATE_CHI
GENERATE_DEPRECATEDLIST
GENERATE_DOCBOOK
GENERATE_DOCSET
GENERATE_ECLIPSEHELP
GENERATE_HTML
GENERATE_HTMLHELP
GENERATE_LATEX
GENERATE_LEGEND
GENERATE_MAN
GENERATE_PERLMOD
GENERATE_QHP
GENERATE_RTF
GENERATE_TAGFILE
GENERATE_TESTLIST
GENERATE_TODOLIST
GENERATE_TREEVIEW
GENERATE_XML
GRAPHICAL_HIERARCHY
GROUP_GRAPHS
GROUP_NESTED_COMPOUNDS
# HAVE_DOT
HHC_LOCATION
HIDE_COMPOUND_REFERENCE
HIDE_FRIEND_COMPOUNDS
HIDE_IN_BODY_DOCS
HIDE_SCOPE_NAMES
HIDE_UNDOC_CLASSES
HIDE_UNDOC_MEMBERS
HIDE_UNDOC_RELATIONS
HTML_COLORSTYLE_GAMMA
HTML_COLORSTYLE_HUE
HTML_COLORSTYLE_SAT
HTML_DYNAMIC_SECTIONS
HTML_EXTRA_FILES
HTML_EXTRA_STYLESHEET
HTML_FILE_EXTENSION
HTML_FOOTER
HTML_HEADER
HTML_INDEX_NUM_ENTRIES
HTML_OUTPUT
HTML_STYLESHEET
HTML_TIMESTAMP
IDL_PROPERTY_SUPPORT
IGNORE_PREFIX
IMAGE_PATH
INCLUDED_BY_GRAPH
INCLUDE_FILE_PATTERNS
INCLUDE_GRAPH
INCLUDE_PATH
INHERIT_DOCS
INLINE_GROUPED_CLASSES
INLINE_INFO
INLINE_INHERITED_MEMB
INLINE_SIMPLE_STRUCTS
INLINE_SOURCES
INPUT
INPUT_ENCODING
INPUT_FILTER
INTERACTIVE_SVG
INTERNAL_DOCS
JAVADOC_AUTOBRIEF
LATEX_BATCHMODE
LATEX_BIB_STYLE
LATEX_CMD_NAME
LATEX_EXTRA_FILES
LATEX_EXTRA_STYLESHEET
LATEX_FOOTER
LATEX_HEADER
LATEX_HIDE_INDICES
LATEX_OUTPUT
LATEX_SOURCE_CODE
LATEX_TIMESTAMP
LAYOUT_FILE
LOOKUP_CACHE_SIZE
MACRO_EXPANSION
MAKEINDEX_CMD_NAME
MAN_EXTENSION
MAN_LINKS
MAN_OUTPUT
MAN_SUBDIR
MARKDOWN_SUPPORT
MATHJAX_CODEFILE
MATHJAX_EXTENSIONS
MATHJAX_FORMAT
MATHJAX_RELPATH
MAX_DOT_GRAPH_DEPTH
MAX_INITIALIZER_LINES
MSCFILE_DIRS
MSCGEN_PATH
MULTILINE_CPP_IS_BRIEF
OPTIMIZE_FOR_FORTRAN
OPTIMIZE_OUTPUT_FOR_C
OPTIMIZE_OUTPUT_JAVA
OPTIMIZE_OUTPUT_VHDL
OUTPUT_DIRECTORY
OUTPUT_LANGUAGE
PAPER_TYPE
PDF_HYPERLINKS
PERLMOD_LATEX
PERLMOD_MAKEVAR_PREFIX
PERLMOD_PRETTY
PERL_PATH
PLANTUML_CFG_FILE
PLANTUML_INCLUDE_PATH
PLANTUML_JAR_PATH
PREDEFINED
PROJECT_BRIEF
PROJECT_LOGO
PROJECT_NAME
PROJECT_NUMBER
QCH_FILE
QHG_LOCATION
QHP_CUST_FILTER_ATTRS
QHP_CUST_FILTER_NAME
QHP_NAMESPACE
QHP_SECT_FILTER_ATTRS
QHP_VIRTUAL_FOLDER
QT_AUTOBRIEF
QUIET
RECURSIVE
REFERENCED_BY_RELATION
REFERENCES_LINK_SOURCE
REFERENCES_RELATION
REPEAT_BRIEF
RTF_EXTENSIONS_FILE
RTF_HYPERLINKS
RTF_OUTPUT
RTF_SOURCE_CODE
RTF_STYLESHEET_FILE
SEARCHDATA_FILE
SEARCHENGINE
SEARCHENGINE_URL
SEARCH_INCLUDES
SEPARATE_MEMBER_PAGES
SERVER_BASED_SEARCH
SHORT_NAMES
SHOW_FILES
SHOW_GROUPED_MEMB_INC
SHOW_INCLUDE_FILES
SHOW_NAMESPACES
SHOW_USED_FILES
SIP_SUPPORT
SKIP_FUNCTION_MACROS
SORT_BRIEF_DOCS
SORT_BY_SCOPE_NAME
SORT_GROUP_NAMES
SORT_MEMBERS_CTORS_1ST
SORT_MEMBER_DOCS
SOURCE_BROWSER
SOURCE_TOOLTIPS
STRICT_PROTO_MATCHING
STRIP_CODE_COMMENTS
STRIP_FROM_INC_PATH
STRIP_FROM_PATH
SUBGROUPING
TAB_SIZE
TAGFILES
TCL_SUBST
TEMPLATE_RELATIONS
TOC_EXPAND
TOC_INCLUDE_HEADINGS
TREEVIEW_WIDTH
TYPEDEF_HIDES_STRUCT
UML_LIMIT_NUM_FIELDS
UML_LOOK
USE_HTAGS
USE_MATHJAX
USE_MDFILE_AS_MAINPAGE
USE_PDFLATEX
VERBATIM_HEADERS
WARNINGS
WARN_AS_ERROR
WARN_FORMAT
WARN_IF_DOC_ERROR
WARN_IF_UNDOCUMENTED
WARN_LOGFILE
WARN_NO_PARAMDOC
XML_OUTPUT
XML_PROGRAMLISTING
)
set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file")
function(add_doxygen_doc)
set(options)
set(oneValueArgs)
set(multiValueArgs DEPENDS ${DOXYGEN_ARGS})
cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
file(WRITE ${DOXYGEN_CONFIG_FILE} "# Auto-generated doxygen configuration file\n")
foreach(ARG ${DOXYGEN_ARGS})
if(PARSE_${ARG})
string(REPLACE ";" " " ARG_VALUE ${PARSE_${ARG}})
file(APPEND ${DOXYGEN_CONFIG_FILE} "\n${ARG} = ${ARG_VALUE}\n")
endif()
endforeach()
if(PARSE_OUTPUT_DIRECTORY)
if(NOT EXISTS ${PARSE_OUTPUT_DIRECTORY})
file(MAKE_DIRECTORY ${PARSE_OUTPUT_DIRECTORY})
endif()
endif()
if(DOT_EXECUTABLE)
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nDOT_PATH = \"${DOT_EXECUTABLE}\"\n")
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = YES\n")
else()
file(APPEND ${DOXYGEN_CONFIG_FILE} "\nHAVE_DOT = NO\n")
endif()
add_custom_target(doxygen
${DOXYGEN_EXECUTABLE} ${DOXYGEN_CONFIG_FILE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "Building documentation with doxygen"
)
if(PARSE_OUTPUT_DIRECTORY)
clean_doc_output(${PARSE_OUTPUT_DIRECTORY})
endif()
mark_as_doc(doxygen)
if(PARSE_DEPENDS)
add_dependencies(doxygen ${PARSE_DEPENDS})
endif()
endfunction()

View File

@@ -0,0 +1,110 @@
################################################################################
#
# MIT License
#
# Copyright (c) 2017 Advanced Micro Devices, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
################################################################################
# - Enable warning all for gcc/clang or use /W4 for visual studio
## Strict warning level
if (MSVC)
# Use the highest warning level for visual studio.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /w")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /w")
# set(CMAKE_CXX_WARNING_LEVEL 4)
# if (CMAKE_CXX_FLAGS MATCHES "/W[0-4]")
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
# else ()
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4")
# endif ()
# set(CMAKE_C_WARNING_LEVEL 4)
# if (CMAKE_C_FLAGS MATCHES "/W[0-4]")
# string(REGEX REPLACE "/W[0-4]" "/W4" CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
# else ()
# set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W4")
# endif ()
else()
foreach(COMPILER C CXX)
set(CMAKE_COMPILER_WARNINGS)
# use -Wall for gcc and clang
list(APPEND CMAKE_COMPILER_WARNINGS
-Wall
-Wextra
-Wcomment
-Wendif-labels
-Wformat
-Winit-self
-Wreturn-type
-Wsequence-point
# Shadow is broken on gcc when using lambdas
# -Wshadow
-Wswitch
-Wtrigraphs
-Wundef
-Wuninitialized
-Wunreachable-code
-Wunused
-Wno-sign-compare
-Wno-extra-semi-stmt
)
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang")
list(APPEND CMAKE_COMPILER_WARNINGS
-Weverything
-Wno-c++98-compat
-Wno-c++98-compat-pedantic
-Wno-conversion
-Wno-double-promotion
-Wno-exit-time-destructors
-Wno-extra-semi
-Wno-float-conversion
-Wno-gnu-anonymous-struct
-Wno-gnu-zero-variadic-macro-arguments
-Wno-missing-prototypes
-Wno-nested-anon-types
-Wno-padded
-Wno-return-std-move-in-c++11
-Wno-shorten-64-to-32
-Wno-sign-conversion
-Wno-unknown-warning-option
-Wno-unused-command-line-argument
-Wno-weak-vtables
-Wno-covered-switch-default
)
else()
if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX")
# cmake 3.5.2 does not support >=.
if(NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS "6.1")
list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-ignored-attributes)
endif()
endif()
list(APPEND CMAKE_COMPILER_WARNINGS
-Wno-missing-field-initializers
-Wno-deprecated-declarations
)
endif()
add_definitions(${CMAKE_COMPILER_WARNINGS})
endforeach()
endif ()

View File

@@ -0,0 +1,14 @@
#ifndef CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
template <typename GridwiseOp, typename... Xs>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
run_gridwise_operation(Xs... xs)
{
GridwiseOp{}.Run(xs...);
}
#endif

View File

@@ -0,0 +1,272 @@
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// Number of GEMMs = YTilda * XTilda
// GemmM = C
// GemmN = N * HTildaSlice * WTildaSlice
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t IYTildaValue,
index_t IXTildaValue,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<IYTildaValue>,
Number<IXTildaValue>,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilda = Number<IYTildaValue>{};
constexpr auto IXTilda = Number<IXTildaValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda);
const auto XDot = math::integer_divide_ceil(X, XTilda);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
const auto IHTildaSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
#if 1
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// output tensor
// this add padding check
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
#if 1
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_freeze_transform(IXTilda),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_pass_through_transform(C),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,275 @@
#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: out
// B: wei
// C: in
// Number of GEMMs = YTilda * XTilda
// GemmM = N * HTildaSlice * WTildaSlice
// GemmN = C
// GemmK = K * YDotSlice * XDotSlice
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t IYTildaValue,
index_t IXTildaValue,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<IYTildaValue>,
Number<IXTildaValue>,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
constexpr auto IYTilda = Number<IYTildaValue>{};
constexpr auto IXTilda = Number<IXTildaValue>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto YTilda = ConvStrideH / GcdStrideDilationH;
const auto XTilda = ConvStrideW / GcdStrideDilationW;
const auto YDot = math::integer_divide_ceil(Y, YTilda);
const auto XDot = math::integer_divide_ceil(X, XTilda);
const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilda and WTilda that contribute to non-padding area of input tensor
const auto IHTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
const auto IWTildaSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);
const auto IHTildaSliceEnd =
math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildaSliceEnd =
math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;
// GemmK is different for each GEMM
const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda);
const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda);
const auto K1 = GemmK1;
const auto K0 = K / K1;
// A: output tensor
// this add padding check
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YDot, HTilda),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, WTilda),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilda_xdot_wtilda_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_unmerge_transform(make_tuple(K0, K1))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
#if 1
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(K1)),
make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// B: weight tensor
const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilda),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilda),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(IYTilda),
make_freeze_transform(IXTilda),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
#if 1
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#else
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)),
make_pass_through_transform(C),
make_pass_through_transform(K1)),
make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
#endif
// C: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilda, HTilda),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilda, WTilda),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(IYTilda),
make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice),
make_freeze_transform(IXTilda),
make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<3>{}));
const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
in_n_htildaslice_wtildaslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,257 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_global_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1(
const TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(
wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,176 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
}
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 &&
InRightPadW == 0);
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(
wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,129 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)),
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,129 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,132 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// A: in
// B: wei
// C: out
// GemmM = N * Ho * Wo
// GemmN = K
// GemmK = Y * X * C
template <typename... In,
typename... Wei,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
const TensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const TensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const TensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = N * Ho * Wo;
const auto GemmN = K;
const auto GemmK = Y * X * C;
const auto GemmK0 = GemmK / GemmK1;
// A: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmm_grid_desc =
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
transform_tensor_descriptor(in_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: weight tensor
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// C: output tensor
const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,132 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// GemmM0 = 1
// GemmM1 = K
// GemmN0 = N0
// GemmN1 = (N / N0) * Ho * Wo
// GemmK0 = (C / C0) * Y * X
// GemmK1 = C0
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
typename N0Type,
typename C0Type>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const TensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const TensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const TensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const N0Type& N0,
const C0Type& C0)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto N1 = N / N0;
const auto C1 = C / C0;
// weight tensor
const auto wei_gk0_gm0_gm1_gk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(C0, C1)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor(
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
make_pass_through_transform(N0),
make_merge_transform(make_tuple(N1, Ho, Wo)),
make_pass_through_transform(C0)),
make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_n_k_howo_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc =
transform_tensor_descriptor(out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(N0),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return make_tuple(
wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,33 @@
#ifndef CK_CLUSTER_DESCRIPTOR_HPP
#define CK_CLUSTER_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
namespace ck {
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
const Lengths& lengths,
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
constexpr index_t ndim_low = Lengths::Size();
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
const auto low_lengths = generate_tuple(
[&](auto idim_low) { return reordered_lengths[idim_low]; }, Number<ndim_low>{});
const auto transform = make_merge_transform(low_lengths);
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
constexpr auto up_dim_new_top_ids = Sequence<0>{};
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
} // namespace ck
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,103 @@
#ifndef CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#define CK_MULTI_INDEX_TRANSFORM_HELPER_HPP
#include "common_header.hpp"
#include "multi_index_transform.hpp"
namespace ck {
template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{
return PassThrough<LowLength>{low_length};
}
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto
make_pad_transform(const LowLength& low_length,
const LeftPad& left_pad,
const RightPad& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_left_pad_transform(
const LowLength& low_length,
const LeftPadLength& left_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck>
__host__ __device__ constexpr auto make_right_pad_transform(
const LowLength& low_length,
const RightPadLength& right_pad,
integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
}
template <typename UpLengths,
typename Coefficients,
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients)
{
return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return Merge_v1_carry_check<LowLengths>{low_lengths};
#else
#if 1
return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
#endif
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
return Merge_v2_magic_division<LowLengths>{low_lengths};
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths,
integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
{
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
return Freeze<LowerIndex>{low_idx};
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
{
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
} // namespace ck
#endif

View File

@@ -0,0 +1,464 @@
#ifndef CK_TENSOR_ADAPTOR_HPP
#define CK_TENSOR_ADAPTOR_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
typename LowerDimensionHiddenIdss,
typename UpperDimensionHiddenIdss,
typename BottomDimensionHiddenIds,
typename TopDimensionHiddenIds>
struct TensorAdaptor
{
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
__host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
__host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
{
return LowerDimensionHiddenIdss{};
}
__host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
{
return UpperDimensionHiddenIdss{};
}
__host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
{
return TopDimensionHiddenIds{};
}
__host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
{
return BottomDimensionHiddenIds{};
}
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
return length;
},
Number<ndim_top_>{});
// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
{
constexpr auto idim_top = Number<IDim>{};
constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
{
return BottomDimensionHiddenIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfTopDimension()
{
return TopDimensionHiddenIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
LowerDimensionHiddenIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); },
UpperDimensionHiddenIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
math::less<index_t>,
math::equal<index_t>>::type;
return unique_sort_all_dim_ids::Size();
}
constexpr static index_t ntransform_ = GetNumOfTransform();
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
constexpr static index_t ndim_top_ = GetNumOfTopDimension();
using HiddenIndex = MultiIndex<ndim_hidden_>;
using BottomIndex = MultiIndex<ndim_bottom_>;
using TopIndex = MultiIndex<ndim_top_>;
// may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public:
__host__ __device__ constexpr TensorAdaptor() = default;
__host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
: transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
{
static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionHiddenIdss::Size() == ntransform_ &&
UpperDimensionHiddenIdss::Size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = GetNumOfTransform();
constexpr index_t ndim_hidden = GetNumOfHiddenDimension();
MultiIndex<ndim_hidden> idx_hidden;
// initialize uppest index
set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
auto itran = itran_p1 - Number<1>{};
const auto& tran = GetTransforms().At(itran);
constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran);
constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &=
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
});
return is_known && is_known_at_compile_time<ElementSize>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("TensorAdaptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionHiddenIds:");
LowerDimensionHiddenIdss{}.At(i).Print();
printf("UpperDimensionHiddenIds:");
UpperDimensionHiddenIdss{}.At(i).Print();
});
printf("BottomDimensionHiddenIds:");
BottomDimensionHiddenIds::Print();
printf("TopDimensionHiddenIds:");
TopDimensionHiddenIds::Print();
printf("}");
}
private:
Transforms transforms_;
ElementSize element_size_;
};
template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
const TensorAdaptor1& adaptor1)
{
static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
TensorAdaptor1::GetNumOfBottomDimension(),
"wrong!");
// all_transforms = transform0 + transform1
const auto all_transforms =
container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms());
// shift
constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();
static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
adaptor0_max_hidden_id_ =
math::max(adaptor0_max_hidden_id_,
TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
});
constexpr index_t ndim_up =
TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor0_max_hidden_id_ =
math::max(adaptor0_max_hidden_id_,
TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
});
});
return adaptor0_max_hidden_id_;
}();
constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();
static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();
// get the min of all lower dimenions, but not bottom dimension (because their id will
// be matched with top id from adaptor0)
static_for<0, ndim_low, 1>{}([&](auto idim_low) {
constexpr index_t low_dim_hidden_id =
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;
bool is_bottom_dim = false;
static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
if constexpr(low_dim_hidden_id ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
{
is_bottom_dim = true;
}
});
if(!is_bottom_dim)
{
adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
}
});
constexpr index_t ndim_up =
TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();
// get the min of all upper dimensions
static_for<0, ndim_up, 1>{}([&](auto idim_up) {
adaptor1_min_hidden_id_ =
math::min(adaptor1_min_hidden_id_,
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
});
});
return adaptor1_min_hidden_id_;
}();
constexpr index_t adaptor1_hidden_id_shift =
adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;
constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();
// all_low_dim_hidden_idss =
// low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
constexpr auto low_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();
constexpr auto low_dim_hidden_ids_1 =
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];
// sequence in, sequence out
constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
{
auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);
// shift hidden id so every dim id is unique
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
});
// match hidden id
static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
// if this low dim is bottom dim, then do id matching
if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
{
low_dim_hidden_ids_1_mod_(idim_low_1) =
TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
}
});
});
return low_dim_hidden_ids_1_mod_;
}
();
return generate_sequence_v2(
[&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
Number<ndim_low_1>{});
},
Number<TensorAdaptor1::GetNumOfTransform()>{});
constexpr auto all_low_dim_hidden_idss =
container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);
// all_up_dim_hidden_idss =
// up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1)
constexpr auto up_dim_hidden_idss_1 = generate_tuple(
// generate sequence of ids for a transform
[&](auto itran) {
constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();
constexpr auto up_dim_hidden_ids_1 =
TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];
// sequence in, constexpr tuple out
constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
{
auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);
// shift hidden id
static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
});
return up_dim_hidden_ids_1_mod_;
}
();
// constexpr tuple to sequence
return generate_sequence_v2(
[&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
Number<ndim_up_1>{});
},
Number<TensorAdaptor1::GetNumOfTransform()>{});
constexpr auto all_up_dim_hidden_idss =
container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);
// bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();
// top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
constexpr auto top_dim_hidden_ids =
TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};
// put everything together
return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
LowerDimensionOldTopIdss,
UpperDimensionNewTopIdss)
{
constexpr index_t ntransform = Transforms::Size();
static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
UpperDimensionNewTopIdss::Size() == ntransform,
"wrong!");
// sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
constexpr auto all_low_dim_old_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});
constexpr auto all_up_dim_new_top_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});
static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
"wrong!");
constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();
// low_dim_hidden_idss
constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};
// up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
constexpr auto up_dim_hidden_idss = generate_tuple(
[](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
Number<ntransform>{});
// bottom_dim_hidden_ids
constexpr auto bottom_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};
// top_dim_hidden_ids
constexpr auto top_dim_hidden_ids =
typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};
return TensorAdaptor<remove_cv_t<Transforms>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(bottom_dim_hidden_ids)>,
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}
} // namespace ck
#endif

View File

@@ -0,0 +1,597 @@
#ifndef CK_TENSOR_DESCRIPTOR_HPP
#define CK_TENSOR_DESCRIPTOR_HPP
#include "common_header.hpp"
#include "multi_index_transform.hpp"
namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate;
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep;
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
// UpperDimensionIdss : Tuple<Sequence<...>, ...>
// VisibleDimensionIds> : Sequence<...>
template <typename Transforms,
typename LowerDimensionIdss,
typename UpperDimensionIdss,
typename VisibleDimensionIds,
typename ElementSpaceSize>
struct TensorDescriptor
{
// TODO make these private
__host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }
__host__ __device__ static constexpr index_t GetNumOfVisibleDimension()
{
return VisibleDimensionIds::Size();
}
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
{
constexpr auto all_low_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
constexpr auto all_up_dim_ids = unpack(
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
math::less<index_t>,
math::equal<index_t>>::type;
return unique_sort_all_dim_ids::Size();
}
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_visible) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
return length;
},
Number<ndim_visible_>{});
// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
}
template <index_t IDim>
__host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
{
constexpr auto idim_visible = Number<IDim>{};
constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);
index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;
static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];
static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
return make_tuple(itran_found, idim_up_found, found);
}
constexpr static index_t ntransform_ = GetNumOfTransform();
constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension();
constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
using VisibleIndex = MultiIndex<ndim_visible_>;
using HiddenIndex = MultiIndex<ndim_hidden_>;
using Coordinate = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;
// may be index_t or Number<>
using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;
public:
__host__ __device__ constexpr TensorDescriptor() = default;
__host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
ElementSpaceSize element_space_size)
: transforms_{transforms},
element_size_{InitializeElementSize(transforms)},
element_space_size_{element_space_size}
{
static_assert(Transforms::Size() == ntransform_ &&
LowerDimensionIdss::Size() == ntransform_ &&
UpperDimensionIdss::Size() == ntransform_,
"wrong! inconsistent # of transformations");
// TODO check dependency of dimensions is valid
}
__host__ __device__ static constexpr index_t GetNumOfDimension()
{
return GetNumOfVisibleDimension();
}
template <index_t IDim>
__host__ __device__ constexpr auto GetLength(Number<IDim>) const
{
static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");
constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});
constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];
static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
}
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
template <typename Idx>
__host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
{
static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");
return make_tensor_coordinate(*this, idx).GetOffset();
}
// TODO make these private
__host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }
__host__ __device__ static constexpr auto GetLowerDimensionIdss()
{
return LowerDimensionIdss{};
}
__host__ __device__ static constexpr auto GetUpperDimensionIdss()
{
return UpperDimensionIdss{};
}
__host__ __device__ static constexpr auto GetVisibleDimensionIds()
{
return VisibleDimensionIds{};
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &=
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
});
return is_known && is_known_at_compile_time<ElementSize>::value &&
is_known_at_compile_time<ElementSpaceSize>::value;
}
__host__ __device__ void Print() const
{
printf("{");
printf("TensorDescriptor, ");
static_for<0, ntransform_, 1>{}([&](auto i) {
printf("transforms: ");
transforms_[i].Print();
printf("LowerDimensionIds:");
LowerDimensionIdss{}.At(i).Print();
printf("UpperDimensionIds:");
UpperDimensionIdss{}.At(i).Print();
});
printf("}");
VisibleDimensionIds::Print();
}
// TODO make these private
Transforms transforms_;
ElementSize element_size_;
ElementSpaceSize element_space_size_;
};
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate
{
// TODO make these private
static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using VisibleIndex = MultiIndex<ndim_visible_>;
public:
__host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
: idx_hidden_{idx_hidden}
{
}
__host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); }
__host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; }
// TODO make these private
__host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }
__host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; }
__host__ __device__ constexpr auto GetVisibleIndex() const
{
return get_container_subset(idx_hidden_, VisibleDimensionIds{});
}
// TODO make these private
HiddenIndex idx_hidden_;
};
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep
{
// TODO make these private
using VisibleIndex = MultiIndex<NDimVisible>;
public:
__host__ __device__ constexpr TensorCoordinateStep() = default;
__host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
const MultiIndex<NTransform>& do_transforms)
: idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
{
}
__host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); }
// TODO make these private
__host__ __device__ constexpr const auto& GetVisibleIndexDiff() const
{
return idx_diff_visible_;
}
VisibleIndex idx_diff_visible_;
MultiIndex<NTransform> do_transforms_;
// HACK: control UpdateLowerIndex()
static constexpr UpdateLowerIndexHack update_lower_index_hack_;
};
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
// (transform_tensor_descriptor) because template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
template <typename I>
__host__ __device__ constexpr auto operator()(I) const
{
using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
return Number<Tran::GetNumOfUpperDimension()>{};
}
};
template <typename OldTensorDescriptor,
typename NewTransforms,
typename NewLowerDimensionOldVisibleIdss,
typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
const NewTransforms& new_transforms,
NewLowerDimensionOldVisibleIdss,
NewUpperDimensionNewVisibleIdss)
{
// sanity check
{
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
"wrong!");
}
// lower dimension's hidden idss
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
// sequences)
constexpr auto low_dim_hidden_idss = transform_tuples(
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
[](auto low_dim_visible_ids) constexpr {
return transform_sequences(
// convert lower dimension visible id to hidden id
[](auto low_dim_visible_id) constexpr {
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
},
low_dim_visible_ids);
},
NewLowerDimensionOldVisibleIdss{});
constexpr index_t num_new_transform = NewTransforms::Size();
// upper dimension's hidden idss
constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension();
constexpr auto up_dim_numbers =
generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});
constexpr auto up_dim_numbers_scan = merge_sequences(
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
constexpr auto up_dim_hidden_idss = generate_tuple(
[ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr {
return
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
1>::type{};
},
Number<num_new_transform>{});
// new visible dimension's hidden ids
constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
constexpr auto new_visible_dim_unordered2ordered = unpack(
[](auto... xs) constexpr { return merge_sequences(xs...); },
NewUpperDimensionNewVisibleIdss{});
constexpr auto new_visible_dim_hidden_ids =
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);
// put everything together
const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);
constexpr auto all_low_dim_hidden_idss =
container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);
constexpr auto all_up_dim_hidden_idss =
container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);
const auto element_space_size = old_tensor_desc.GetElementSpaceSize();
return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
remove_cv_t<decltype(all_low_dim_hidden_idss)>,
remove_cv_t<decltype(all_up_dim_hidden_idss)>,
remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{all_transforms,
element_space_size};
}
template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const VisibleIndex& idx_visible)
{
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
MultiIndex<ndim_hidden> idx_hidden;
// initialize visible index
set_container_subset(idx_hidden, visible_dim_ids, idx_visible);
// calculate hidden index
static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
auto itran = itran_p1 - Number<1>{};
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto idx_up = get_container_subset(idx_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_low;
tran.CalculateLowerIndex(idx_low, idx_up);
set_container_subset(idx_hidden, dims_low, idx_low);
});
return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const VisibleIndex& idx_diff_visible,
UpdateLowerIndexHack)
{
static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
"wrong! # of dimension inconsistent");
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();
static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");
// use index_t for boolean type
auto do_transforms = make_zero_multi_index<ntransform>();
auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();
// decide do_transform by checkout non-zero index diff components
MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;
static_for<0, ndim_visible, 1>{}(
[&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });
set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);
MultiIndex<dims_low.Size()> non_zero_diff_pick_low;
// if any of upper index diff components is non-zero, then
// 1) Need to do this transform
// 2) all components of lower index diff will assume to be non-zero and need to be
// computed
const bool idx_diff_up_has_non_zero = container_reduce(
non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);
do_transforms(itran) = idx_diff_up_has_non_zero;
static_for<0, dims_low.Size(), 1>{}(
[&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });
set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
});
return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
do_transforms};
}
template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
const VisibleIndex& idx_diff_visible)
{
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
return make_tensor_coordinate_step(
TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
}
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
TensorCoord& coord,
const TensorCoordStep& coord_step)
{
constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
// this is what needs to be calculated
auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();
// initialize visible index diff
set_container_subset(
idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());
// this is what needs to be updated
auto& idx_hidden = coord.GetHiddenIndex();
// update visible index
auto idx_hidden_pick_visible =
get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());
idx_hidden_pick_visible += coord_step.GetIndexDiff();
set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);
// update rest of hidden index
static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
if(coord_step.do_transforms_[itran])
{
const auto& tran = tensor_desc.GetTransforms().At(itran);
constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran);
const auto idx_up_new = get_container_subset(idx_hidden, dims_up);
auto idx_low = get_container_subset(idx_hidden, dims_low);
const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);
MultiIndex<dims_low.Size()> idx_diff_low;
// HACK: control UpdateLowerIndex for Merge using hack
constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);
tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
set_container_subset(idx_hidden, dims_low, idx_low);
}
});
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool
coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
bool valid = true;
constexpr index_t ntransform = TensorDesc::GetNumOfTransform();
const auto& idx_hidden = coord.GetHiddenIndex();
static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
const auto tran = tensor_desc.GetTransforms().At(itran);
// check validity, only if current transformation does not always has a valid mapping
if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
{
const auto idx_up =
get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));
// Comment: using valid = valid && .. will result in weird control flow in ISA
valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
}
});
return valid;
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
// check visible index
const auto& idx_visible = coord.GetVisibleIndex();
bool is_visible_index_valid = true;
static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
[&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
is_visible_index_valid =
is_visible_index_valid &&
(idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
});
// check other hidden index
return is_visible_index_valid &&
coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
TensorDesc{}, MultiIndex<remove_cv_t<remove_reference_t<TensorDesc>>::GetNumOfDimension()>{}));
} // namespace ck
#endif

View File

@@ -0,0 +1,149 @@
#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP
#define CK_TENSOR_DESCRIPTOR_HELPER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "multi_index_transform_helper.hpp"
namespace ck {
/*
* These functions create tensor descriptor at runtime. If they are not constexpr, you will
* likely see usage of scratch memory during construction of these tensor descriptors. So
* it's better to call these functions on host and then pass the constructed tensor descritpors
* to GPU. If the tensor descritpors being constructed are constexpr, then you can call these
* functions on GPU without worrying about scratch memory usage.
*/
#if CK_WORKAROUND_SWDEV_275126
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
const Strides& strides,
Number<I> i,
AccOld acc_old)
{
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
if constexpr(i.value < Lengths::Size() - 1)
{
return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
}
#endif
template <typename... Lengths,
typename... Strides,
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides)
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_embed_transform(lengths, strides));
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
#if !CK_WORKAROUND_SWDEV_275126
// rocm-4.1 compiler would crash for recursive labmda
// recursive function for reduction
auto f = [&](auto fs, auto i, auto acc_old) {
auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];
if constexpr(i.value < N - 1)
{
return fs(fs, i + Number<1>{}, acc_new);
}
else
{
return acc_new;
}
};
const auto element_space_size = f(f, Number<0>{}, Number<1>{});
#else
const auto element_space_size =
calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{});
#endif
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
// Lengths... can be:
// 1) index_t, which is known at run-time
// 2) Number<>, which is known at compile-time
template <typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{
constexpr index_t N = sizeof...(Lengths);
const auto transforms = make_tuple(make_unmerge_transform(lengths));
constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});
constexpr auto up_dim_hidden_idss =
make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{});
return TensorDescriptor<remove_cv_t<decltype(transforms)>,
remove_cv_t<decltype(low_dim_hidden_idss)>,
remove_cv_t<decltype(up_dim_hidden_idss)>,
remove_cv_t<decltype(visible_dim_hidden_ids)>,
remove_cv_t<decltype(element_space_size)>>{transforms,
element_space_size};
}
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
{
constexpr auto I1 = Number<1>{};
constexpr index_t N = sizeof...(Lengths);
const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number<N - 1>{}], align);
auto strides = generate_tuple(
[&](auto i) {
if constexpr(i.value == N - 1)
{
return I1;
}
else if constexpr(i.value == N - 2)
{
return Number<stride_n_minus_2>{};
}
else
{
return container_reduce(lengths,
math::multiplies{},
Number<stride_n_minus_2>{},
i + I1,
Number<N - 1>{},
I1);
}
},
Number<N>{});
return make_naive_tensor_descriptor(lengths, strides);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,394 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. AKMBlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BKNBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <
index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AKMBlockDesc,
typename BKNBlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
static constexpr index_t M100 = M1N1ThreadClusterM100;
static constexpr index_t N100 = M1N1ThreadClusterN100;
static constexpr index_t M101 = M1N1ThreadClusterM101;
static constexpr index_t N101 = M1N1ThreadClusterN101;
static constexpr index_t M11 = M1PerThreadM11;
static constexpr index_t N11 = N1PerThreadN11;
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
static constexpr index_t M0 = M / M1;
static constexpr index_t N0 = N / N1;
__host__ __device__ static constexpr auto
MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */)
{
const auto a_k_m0_m1_block_desc = transform_tensor_descriptor(
AKMBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return a_k_m0_m1_block_desc;
}
__host__ __device__ static constexpr auto
MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */)
{
const auto b_k_n0_n1_block_desc = transform_tensor_descriptor(
BKNBlockDesc{},
make_tuple(make_pass_through_transform(Number<K>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return b_k_n0_n1_block_desc;
}
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_unmerge_transform(make_tuple(
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
}
__host__ __device__ static constexpr auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<M0>{}),
make_unmerge_transform(
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_pass_through_transform(Number<N0>{}),
make_unmerge_transform(
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
}
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
{
return Sequence<M0, M11, N0, N11>{};
}
static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});
public:
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
{
static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M101 * M100 * N101 * N100,
"wrong! blocksize and cluster size not consistent");
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2, "wrong");
}
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
make_pass_through_transform(M0),
make_pass_through_transform(M11),
make_pass_through_transform(N0),
make_pass_through_transform(N11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
__host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }
__host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; }
template <typename CM0M1N0N1ThreadDesc,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
a_k_m0_m1_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
b_k_n0_n1_thread_desc_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_k_m0_m1_thread_desc_),
decltype(b_k_n0_n1_thread_desc_),
CM0M1N0N1ThreadDesc,
Sequence<KPerThread>,
Sequence<1, M1PerThreadM11>,
Sequence<1, N1PerThreadN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(I0, I0, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(I0, I0, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(I0, I1, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(I0, I1, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(k, I0, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(k, I0, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_k_n0_n1_block_desc_,
make_tuple(k, I1, I0),
b_block_buf,
b_k_n0_n1_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k_m0_m1_block_desc_,
make_tuple(k, I1, I0),
a_block_buf,
a_k_m0_m1_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[K, M0, M1]
static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));
// B[K, N0, N1]
static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
decltype(a_k_m0_m1_block_desc_),
decltype(a_k_m0_m1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M11,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
FloatB,
decltype(b_k_n0_n1_block_desc_),
decltype(b_k_n0_n1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N11,
1>;
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,410 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace ck {
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
using CIndex = MultiIndex<4>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];
static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t BM1 = BM100 * BM101 * BM11;
static constexpr index_t BN1 = BN100 * BN101 * BN11;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
__host__ __device__ static constexpr auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
{
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
"wrong! blocksize and cluster size not consistent");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO remove this restriction
static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
BM10BN10ThreadClusterBN10Xs::Size() == 2,
"wrong!");
// TODO: remove this restriction
static_assert(BM0 == 2 && BN0 == 2, "wrong");
}
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);
return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
}
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(BM0 == 2 && BN0 == 2 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_contraction =
ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of bk0
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
// read A_sub_0
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I0, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I0, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I1, I0, I0),
b_block_buf,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I1, I0, I0),
a_block_buf,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_naive_tensor_descriptor_packed(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,185 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "threadwise_gemm_dlops_v3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t KPerThread,
index_t HPerThread,
index_t WPerThread,
index_t EPerThreadLoop,
index_t ThreadGemmADataPerRead_K,
index_t ThreadGemmBDataPerRead_W>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
struct MatrixIndex
{
index_t k;
index_t h;
index_t w;
};
// HACK: fix this @Jing Zhang
static constexpr index_t KPerThreadSubC = 4;
static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
FloatA,
BlockMatrixA,
decltype(a_thread_mtx_),
Sequence<EPerThreadLoop, KPerThreadSubC>,
Sequence<0, 1>,
1,
ThreadGemmADataPerRead_K,
1>;
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
: c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t H = BlockMatrixB{}.GetLength(I2);
constexpr index_t W = BlockMatrixB{}.GetLength(I3);
static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0,
"wrong! Cannot evenly divide work among\n");
constexpr auto KThreadCluster = K / KPerThread;
constexpr auto HThreadCluster = H / HPerThread;
constexpr auto WThreadCluster = W / WPerThread;
static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
"wrong! wrong blocksize\n");
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
return Sequence<KPerThread, 1, HPerThread, WPerThread>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{});
constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{});
constexpr auto num_w_threads = W / WPerThread;
constexpr auto num_h_threads = H / HPerThread;
constexpr auto num_hw_threads = num_w_threads * num_h_threads;
index_t k_thread_id = thread_id / num_hw_threads;
index_t hw_thread_id = thread_id % num_hw_threads;
index_t h_thread_id = hw_thread_id / num_w_threads;
index_t w_thread_id = hw_thread_id % num_w_threads;
return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
}
template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BThreadBuffer& b_thread_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto EPerBlock = a_block_mtx.GetLength(I0);
// HACK: fix this @Jing Zhang
constexpr auto HoPerThreadSubC = 2;
constexpr auto WoPerThreadSubC = 2;
static_assert(KPerThread % KPerThreadSubC == 0, "");
static_assert(HPerThread % HoPerThreadSubC == 0, "");
static_assert(WPerThread % WoPerThreadSubC == 0, "");
// thread A buffer for GEMM
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
a_thread_buf;
constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
FloatB,
FloatC,
decltype(a_thread_mtx_),
decltype(b_thread_mtx_),
decltype(c_thread_mtx_),
HoPerThreadSubC,
WoPerThreadSubC>{};
static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
a_thread_copy_.Run(a_block_mtx,
make_tuple(e_begin, k_begin),
a_block_buf,
a_thread_mtx_,
make_tuple(I0, I0),
a_thread_buf);
static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0),
b_thread_buf,
make_tuple(e_begin, I0, h_begin, w_begin),
c_thread_buf,
make_tuple(k_begin, I0, h_begin, w_begin));
});
});
});
});
}
template <typename ABlockSliceMoveStepIdx>
__device__ void MoveASliceWindow(const BlockMatrixA&,
const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
{
a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
}
private:
MatrixIndex c_thread_begin_mtx_idx_;
AThreadCopy a_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,524 @@
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset, 0);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset, 0);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{m_offset, n_offset};
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
"wrong! K1 dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
vector_type<FloatAB, a_thread_desc_.GetElementSpaceSize()> a_thread_vec;
vector_type<FloatAB, b_thread_desc_.GetElementSpaceSize()> b_thread_vec;
static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
using mfma_input_type =
typename vector_type<FloatAB, xdlops_gemm.mfma_type.k_base>::type;
static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(Number<i>{}) = a_thread_buf[Number<i>{}];
});
static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) {
b_thread_vec.template AsType<FloatAB>()(Number<i>{}) = b_thread_buf[Number<i>{}];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
m0,
n0>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf);
});
});
});
}
private:
// A[K, M]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, MRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, NRepeat, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
K1,
1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
template <index_t BlockSize,
typename FloatAB,
class ABlockDesc,
class BBlockDesc,
index_t MPerWave,
index_t NPerWave,
index_t K1>
struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
{
using CIndex = MultiIndex<2>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto xdlops_gemm = XdlopsGemm<float, MPerWave, NPerWave, K1>{};
static constexpr index_t WaveSize = 64;
static constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
static constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
static constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
static constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
static constexpr index_t MWaves = M1 / MPerWave;
static constexpr index_t NWaves = N1 / NPerWave;
static constexpr index_t MRepeat = M0;
static constexpr index_t NRepeat = N0;
__device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); }
__device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); }
__device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); }
__device__ static auto CalculateAThreadOriginDataIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_m = waveId / NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, m_offset, 0);
}
else
{
const index_t m_offset = waveId_m * MPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, m_offset, 0);
}
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const index_t thread_id = get_thread_local_1d_id();
const index_t waveId = thread_id / WaveSize;
const index_t laneId = thread_id % WaveSize;
const index_t waveId_n = waveId % NWaves;
if constexpr(xdlops_gemm.IsKReduction)
{
const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId);
const index_t k_offset = xdlops_gemm.GetBlkId(laneId);
return make_tuple(k_offset, 0, n_offset, 0);
}
else
{
const index_t n_offset = waveId_n * NPerWave + laneId;
const index_t k_offset = 0;
return make_tuple(k_offset, 0, n_offset, 0);
}
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static CIndex
CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const index_t waveId = get_thread_local_1d_id() / WaveSize;
const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
const index_t waveId_m = waveId / NWaves;
const index_t waveId_n = waveId % NWaves;
const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0];
const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1];
return CIndex{m_offset, n_offset};
}
__device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline()
: a_thread_copy_{CalculateAThreadOriginDataIndex()},
b_thread_copy_{CalculateBThreadOriginDataIndex()}
{
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
"wrong! K dimension not consistent");
static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3),
"wrong! K1 dimension not consistent");
static_assert(BlockSize == MWaves * NWaves * WaveSize,
"BlockSize != MWaves * NWaves * WaveSize\n");
static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!");
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!");
static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!");
}
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
static_for<xdlops_gemm.KPerXdlops, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
0,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
0>(a_thread_buf, b_thread_buf, c_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
xdlops_gemm.template Run<decltype(a_thread_desc_),
decltype(b_thread_desc_),
decltype(c_thread_desc_),
1,
1>(a_thread_buf, b_thread_buf, c_thread_buf);
}
private:
// A[K, M]
static constexpr auto a_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<MRepeat>{}, I1, Number<K1>{}));
// B[K, N]
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, Number<NRepeat>{}, I1, Number<K1>{}));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
FloatAB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<1, 1, 1, K1>,
Sequence<0, 1, 2, 3>,
3,
1, // K1,
1>;
AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,170 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v4
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
}
}
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
// SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_step_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcScalarPerVector,
DstScalarPerVector,
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,156 @@
#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
namespace ck {
// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
InMemoryDataOperationEnum_t DstInMemOp,
typename BlockSliceLengths,
typename ThreadSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
typename SrcVectorTensorLengths,
typename DstVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename DstVectorTensorContiguousDimOrder,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
__device__ constexpr BlockwiseTensorSliceTransfer_v4r1(const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin)
: threadwise_transfer_(
src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
{
static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_assert(
is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
"wrong! BlockSize too small");
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
dst_block_slice_origin + thread_data_idx_begin);
}
}
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
}
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
}
}
// SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_step_hack);
}
}
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r1<ThreadSliceLengths,
DstInMemOp,
SrcData,
DstData,
SrcDesc,
DstDesc,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorTensorLengths,
DstVectorTensorLengths,
SrcVectorTensorContiguousDimOrder,
DstVectorTensorContiguousDimOrder,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,659 @@
#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
typename CGridBlockCluster_BlockId_To_GM10_GN10,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs,
typename BM10BN10ThreadClusterBN10Xs,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to known at compile-time
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// lds max alignment
// TODO: part of them should be moved into blockwise-gemm
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
const index_t GM10 = GM1 / GM11;
const index_t GN10 = GN1 / GN11;
const index_t grid_size = GM10 * GN10;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
{
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
{
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
{
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
}
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
{
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11)),
make_pass_through_transform(GK1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
}
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
constexpr auto BM = GM0 * GM11;
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies{}, I1) *
BM1PerThreadBM11>{};
constexpr auto BN1 =
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies{}, I1) *
BN1PerThreadBN11>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor(
c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
make_pass_through_transform(GN10),
make_merge_transform(make_tuple(GN0, GN11))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
make_pass_through_transform(GN10),
make_unmerge_transform(make_tuple(BN0, BN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
}
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_grid_block_cluster_blockid_to_gm10_gn10;
}
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
// lds max alignment
// TODO: part of them should be moved into blockwise-gemm
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = GK1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, igm10, 0, 0),
a_block_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, ign10, 0, 0),
b_block_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_gk0_bm_gk1),
decltype(b_block_desc_gk0_bn_gk1),
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
BM1PerThreadBM11,
BN1PerThreadBN11>{};
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t gk0_block_on_grid = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
gk0_block_on_grid += 2 * GK0PerBlock;
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
I1,
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
Sequence<1,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
1,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_multi_index(igm10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
ign10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,662 @@
#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AKM0M1GridDesc,
typename BKN0N1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AKM0M1GridDesc a_k_m0_m1_grid_desc,
const BKN0N1GridDesc b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AKM0M1GridDesc,
typename BKN0N1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r2(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m0_m1_grid_desc = *reinterpret_cast<const AKM0M1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k_m0_m1_grid_desc));
const auto b_k_n0_n1_grid_desc = *reinterpret_cast<const BKN0N1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k_n0_n1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_K_M0_M1,
typename ABlockTransferThreadClusterLengths_K_M0_M1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N0_N1,
typename BBlockTransferThreadClusterLengths_K_N0_N1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc,
const BKNGridDesc& b_k_n_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto K = a_k_m_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K == b_k_n_grid_desc.GetLength(I0)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K)
{
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc)
{
const auto K = a_k_m_grid_desc.GetLength(I0);
const auto M = a_k_m_grid_desc.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1;
const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor(
a_k_m_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return a_k_m0_m1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc)
{
const auto K = b_k_n_grid_desc.GetLength(I0);
const auto N = b_k_n_grid_desc.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1;
const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor(
b_k_n_grid_desc,
make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
return b_k_n0_n1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto N11 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{}));
using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AKM0M1GridDesc& a_k_m0_m1_grid_desc,
const BKN0N1GridDesc& b_k_n0_n1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M1>{},
Number<BBlockTransferDstScalarPerVector_N1>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}), max_lds_align);
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m0_m1_grid_desc),
decltype(a_k_m0_m1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k_m0_m1_grid_desc,
make_multi_index(0, im0, 0),
a_k_m0_m1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n0_n1_grid_desc),
decltype(b_k_n0_n1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k_n0_n1_grid_desc,
make_multi_index(0, in0, 0),
b_k_n0_n1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{};
constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_step_hack =
AGridMoveSliceWindowStepHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_step_hack =
BGridMoveSliceWindowStepHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_step_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks);
b_blockwise_copy.RunRead(
b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,650 @@
#ifndef CK_GRIDWISE_GEMM_V1R3_HPP
#define CK_GRIDWISE_GEMM_V1R3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0M0M1K1GridDesc,
typename BK0N0N1K1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0M0M1K1GridDesc,
typename BK0N0N1K1GridDesc,
typename CM0M10M11N0N10N11GridDesc,
typename CBlockIdToM0N0BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dlops_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k0_m0_m1_k1_grid_desc = *reinterpret_cast<const AK0M0M1K1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k0_m0_m1_k1_grid_desc));
const auto b_k0_n0_n1_k1_grid_desc = *reinterpret_cast<const BK0N0N1K1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n0_n1_k1_grid_desc));
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m10_m11_n0_n10_n11_grid_desc));
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_m0_n0_block_cluster_adaptor));
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m0_m1_k1_grid_desc,
b_k0_n0_n1_k1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlockM1,
index_t NPerBlockN1,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
typename M11N11ThreadClusterM110Xs,
typename M11N11ThreadClusterN110Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{
const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{
const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
{
const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
{
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto M1 = Number<MPerBlockM1>{};
const auto M0 = M / M1;
const auto a_k0_m0_m1_k1_grid_desc =
transform_tensor_descriptor(a_k0_m_k1_grid_desc,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(M0, M1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
{
const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto N1 = Number<NPerBlockN1>{};
const auto N0 = N / N1;
const auto b_k0_n0_n1_k1_grid_desc =
transform_tensor_descriptor(b_k0_n_k1_grid_desc,
make_tuple(make_pass_through_transform(K0),
make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
N1PerThreadN111>{};
constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
make_unmerge_transform(make_tuple(N0, N10, N11))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
return c_m0_m10_m11_n0_n10_n11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlockM1>{};
constexpr auto N1 = Number<NPerBlockN1>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());
// divide block work by [M, N]
const auto c_m0_n0_block_cluster_idx =
c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);
// TODO: change this. I think it needs multi-dimensional alignment
constexpr auto max_lds_align = K1;
// TODO: check alignment
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);
// TODO: check alignment
// A matrix in LDS memory, for blockwise GEMM
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);
// TODO: check alignment
// B matrix in LDS memory, for blockwise GEMM
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);
static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
b_k0_n_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc),
decltype(a_k0_m0_m1_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(a_k0_m0_m1_k1_grid_desc,
make_multi_index(0, im0, 0, 0),
a_k0_m0_m1_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc),
decltype(b_k0_n0_n1_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
false,
true>(b_k0_n0_n1_k1_grid_desc,
make_multi_index(0, in0, 0, 0),
b_k0_n0_n1_k1_block_desc,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlockM1] is in LDS
// b_mtx[KPerBlocl, NPerBlockN1] is in LDS
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed(
sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_m10_m11_n10_n11_thread_desc),
decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
.Run(c_m10_m11_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
a_block_slice_copy_step,
AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
b_block_slice_copy_step,
BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K0 - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{});
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{});
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
I1,
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));
const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1],
1,
c_m10_m11_n10_n11_thread_tensor_lengths[I2],
c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
make_multi_index(im0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_grid_buf,
CGridStepHacks{});
}
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,458 @@
#ifndef CK_GRIDWISE_GEMM_V2_HPP
#define CK_GRIDWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "blockwise_gemm_dlops_v3.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
index_t KPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t EPerBlock,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalStepHacks,
typename BGlobalStepHacks,
typename CGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks>
struct GridwiseGemmDlops_km_kn_mn_v3
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto E = EPerBlock * 3 * 3;
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align);
return a_block_space_size * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());
constexpr auto E = EPerBlock * 3 * 3;
// const auto E = a_e_k_global_desc.GetLength(I0);
const auto K = a_e_k_global_desc.GetLength(I1);
const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);
// divide block work by [M, N]
#if 0
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
const auto wo_block_work_num = Wo / Number<WoPerBlock>{};
const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num;
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num;
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#else
// Hack: this force result into SGPR
const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock);
const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock);
const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num;
const index_t k_block_work_id =
__builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num);
const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num;
const index_t ho_block_work_id =
__builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num);
const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num;
#endif
// lds max alignment
constexpr auto max_lds_align =
math::lcm(Number<ABlockTransferDstScalarPerVector_K>{}, Number<KPerBlock>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<EPerBlock>{}, Number<KPerBlock>{}), max_lds_align);
constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<E>{}, Number<KPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerBlock>{}, Number<WoPerBlock>{}));
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k;
const auto ho_thread_id = c_thread_mtx_index.h;
const auto wo_thread_id = c_thread_mtx_index.w;
const index_t k_block_data_on_global = k_block_work_id * KPerBlock;
const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock;
const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock;
const index_t ho_thread_data_on_global =
ho_block_data_on_global + ho_thread_id * HoPerThread;
const index_t wo_thread_data_on_global =
wo_block_data_on_global + wo_thread_id * WoPerThread;
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_e_k_global_desc),
decltype(a_e_k_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_e_k_global_desc,
make_multi_index(0, k_block_data_on_global),
a_e_k_desc,
make_multi_index(0, 0));
constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<EPerBlock>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
auto b_threadwise_transfer =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
decltype(b_e_n_ho_wo_global_desc),
decltype(b_e_n_ho_wo_thread_desc),
Sequence<EPerBlock, 1, HoPerThread, WoPerThread>,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
1,
true>(
b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());
// register allocation for output
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// initialize output thread tensor
ThreadwiseTensorSliceSet_v1<FloatAcc,
decltype(c_k_n_ho_wo_thread_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
.Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{};
constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{};
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
BGlobalMoveSliceWindowStepHacks{};
// double regsiter buffer for b
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize(),
true>
b_thread_even_buf, b_thread_odd_buf;
// LDS double buffer: preload data
{
a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf);
}
__syncthreads();
if constexpr(HasMainKBlockLoop)
{
index_t e_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_even_buf,
b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
e_block_data_begin += 2 * EPerBlock;
} while(e_block_data_begin < E - 2 * EPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
b_thread_slice_copy_step);
b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc,
b_global_buf,
b_e_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
b_thread_odd_buf,
b_e_n_ho_wo_global_step_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
}
// output: register to global memory
{
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{};
const index_t k_thread_data_on_global =
k_block_data_on_global + k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC,
decltype(c_k_n_ho_wo_thread_desc),
decltype(c_k_n_ho_wo_global_desc),
Sequence<KPerThread, 1, HoPerThread, WoPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(
c_k_n_ho_wo_global_desc,
make_multi_index(
k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
.Run(c_k_n_ho_wo_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_k_n_ho_wo_global_desc,
c_global_buf,
c_k_n_ho_wo_global_tensor_step_hacks);
}
}
// pass tensor descriptor by reference
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc& a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc& b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc& c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by their pointers
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const AGlobalDesc* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const BGlobalDesc* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const CGlobalDesc* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *p_a_e_k_global_desc;
const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc;
const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc;
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
// pass tensor descriptors by void*
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ void Run(const void* p_a_e_k_global_desc,
const FloatAB* __restrict__ p_a_global,
const void* p_b_e_n_ho_wo_global_desc,
const FloatAB* __restrict__ p_b_global,
const void* p_c_k_n_ho_wo_global_desc,
FloatC* __restrict__ p_c_global,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>) const
{
const auto a_e_k_global_desc = *reinterpret_cast<const AGlobalDesc*>(p_a_e_k_global_desc);
const auto b_e_n_ho_wo_global_desc =
*reinterpret_cast<const BGlobalDesc*>(p_b_e_n_ho_wo_global_desc);
const auto c_k_n_ho_wo_global_desc =
*reinterpret_cast<const CGlobalDesc*>(p_c_k_n_ho_wo_global_desc);
Run(a_e_k_global_desc,
p_a_global,
b_e_n_ho_wo_global_desc,
p_b_global,
c_k_n_ho_wo_global_desc,
p_c_global,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,799 @@
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CM0M1M2NGridDesc,
typename CBlockClusterAdaptor>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
const auto a_k0_m_k1_grid_desc = *reinterpret_cast<const AK0MK1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast<const CM0M1M2NGridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc));
const auto c_block_cluster_adaptor = *reinterpret_cast<const CBlockClusterAdaptor*>(
cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor));
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_block_cluster_adaptor);
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWave,
index_t NPerWave,
index_t K1Value,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadSliceLengths_K0_M_K1,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_K1,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K0_N_K1,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_K1,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
bool CAccessOrderMRepeatNRepeat>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CMNGridDesc& c_m_n_grid_desc)
{
// TODO: turn on this
static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
"wrong! K1 need to be known at compile-time");
const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
return grid_size;
}
__host__ __device__ static constexpr auto
MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
{
constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerWave, NPerWave, K1>{};
constexpr auto CLayout = xdlops_gemm.GetCLayout();
constexpr auto M0 = Number<CLayout.M1()>{};
constexpr auto M1 = Number<CLayout.N1()>{};
constexpr auto M2 = Number<CLayout.M0()>{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
constexpr auto N1 = Number<CLayout.N0()>{};
const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor(
c_m_n_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
return c_m0_m1_m2_n_grid_desc;
}
__host__ __device__ static constexpr auto
MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
{
const auto M = c_m_n_grid_desc.GetLength(I0);
const auto N = c_m_n_grid_desc.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
const auto M0 = M / M1;
const auto N0 = N / N1;
#if 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
#elif 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
make_tuple(Sequence<1, 0>{}),
make_tuple(Sequence<0>{}));
#endif
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = K1;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k0_m_k1_grid_desc),
decltype(a_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_k0_m_k1_grid_desc,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_k0_m_k1_block_desc,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k0_n_k1_grid_desc),
decltype(b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<1, 0, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_k0_n_k1_grid_desc,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_k0_n_k1_block_desc,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
NPerBlock % (NPerWave * NRepeat) == 0,
"wrong!");
constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MPerBlock / MRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NPerBlock / NRepeat>{})),
make_pass_through_transform(K1)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
const auto blockwise_gemm =
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
FloatAB,
decltype(a_k0_m0_m1_k1_block_desc),
decltype(b_k0_n0_n1_k1_block_desc),
MPerWave,
NPerWave,
K1>{};
constexpr auto CLayout = blockwise_gemm.GetCLayout();
constexpr index_t BlkSize = CLayout.GetBlkSize();
constexpr index_t NumBlks = CLayout.GetNumBlks();
constexpr index_t NumXdlops = CLayout.GetNumXdlops();
static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only");
constexpr auto c_mr_nr_blk_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize(),
true>
c_thread_buf;
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block = p_shared_block;
FloatAB* p_b_block = p_shared_block + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{};
constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{};
constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{};
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
// preload data into LDS
{
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
}
// main body
index_t k_block_data_begin = 0;
do
{
a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc,
a_block_slice_copy_step,
a_k0_m_k1_grid_move_slice_window_step_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc,
b_block_slice_copy_step,
b_k0_n_k1_grid_move_slice_window_step_hack);
a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks);
block_sync_lds();
b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks);
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf);
k_block_data_begin += KPerBlock;
} while(k_block_data_begin < (K0 - KPerBlock));
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#if 0
// output: register to global memory
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr index_t N0 = CLayout.N1();
constexpr index_t N1 = CLayout.N0();
constexpr auto c_m0_m1_m2_n_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<1>{},
Number<1>{},
Number<M0>{},
Number<1>{},
Number<M2>{},
Number<1>{}));
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize(), true>
c_blk_buf_;
static_for<0, MRepeat, 1>{}([&](auto mr_i) {
static_for<0, NRepeat, 1>{}([&](auto nr_i) {
constexpr auto blk_off =
c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i));
static_for<0, BlkSize, 1>{}([&](auto j) {
c_blk_buf_(Number<blk_off * BlkSize + j>{}) =
c_thread_buf[Number<blk_off>{}]
.template AsType<FloatAcc>()[Number<j>{}];
});
});
});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<MRepeat, NRepeat, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves),
n_thread_data_on_grid / (N1 * NWaves),
m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0),
n_thread_data_on_grid % (N1 * NWaves) / N1,
m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid % N1)}
.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_blk_buf_,
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
}
#else
{
constexpr index_t M0 = CLayout.M1();
constexpr index_t M1 = CLayout.N1();
constexpr index_t M2 = CLayout.M0();
constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_grid =
m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_grid =
n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{};
auto c_thread_copy =
ThreadwiseTensorSliceTransfer_v1r3<FloatC,
FloatC,
decltype(c_m0_m1_m2_n_thread_desc),
decltype(c_m0_m1_m2_n_grid_desc),
Sequence<1, 1, 1, 1, M0, 1, M2, 1>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_m2_n_grid_desc,
make_multi_index(0,
0,
0,
0,
m_thread_data_on_grid / (M2 * M1),
m_thread_data_on_grid % (M2 * M1) / M2,
m_thread_data_on_grid % M2,
n_thread_data_on_grid)};
auto init_copy = [&](auto c_thread_idx_) {
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
return c_thread_idx_;
};
auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
};
auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
};
auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
};
auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus);
constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
c_thread_buf[Number<blk_off>{}].template AsType<FloatAcc>(),
c_m0_m1_m2_n_grid_desc,
c_grid_buf,
c_m0_m1_m2_n_grid_tensor_step_hacks);
};
static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or
(MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or
(MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or
(MRepeat == 1 && NRepeat == 1),
"wrong");
if constexpr(MRepeat == 4 && NRepeat == 4)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
nrepeat_plus_copy(make_tuple(I2, I2));
nrepeat_plus_copy(make_tuple(I2, I3));
mrepeat_plus_copy(make_tuple(I3, I3));
nrepeat_minus_copy(make_tuple(I3, I2));
nrepeat_minus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
mrepeat_plus_copy(make_tuple(I2, I2));
mrepeat_plus_copy(make_tuple(I3, I2));
nrepeat_plus_copy(make_tuple(I3, I3));
mrepeat_minus_copy(make_tuple(I2, I3));
mrepeat_minus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
else if constexpr(MRepeat == 4 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
nrepeat_plus_copy(make_tuple(I2, I1));
mrepeat_plus_copy(make_tuple(I3, I1));
nrepeat_minus_copy(make_tuple(I3, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
mrepeat_plus_copy(make_tuple(I2, I0));
mrepeat_plus_copy(make_tuple(I3, I0));
nrepeat_plus_copy(make_tuple(I3, I1));
mrepeat_minus_copy(make_tuple(I2, I1));
mrepeat_minus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 4)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
nrepeat_plus_copy(make_tuple(I0, I3));
mrepeat_plus_copy(make_tuple(I1, I3));
nrepeat_minus_copy(make_tuple(I1, I2));
nrepeat_minus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
nrepeat_plus_copy(make_tuple(I0, I2));
mrepeat_plus_copy(make_tuple(I1, I2));
nrepeat_plus_copy(make_tuple(I1, I3));
mrepeat_minus_copy(make_tuple(I0, I3));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
if constexpr(CAccessOrderMRepeatNRepeat)
{
nrepeat_plus_copy(make_tuple(I0, I1));
mrepeat_plus_copy(make_tuple(I1, I1));
nrepeat_minus_copy(make_tuple(I1, I0));
}
else
{
mrepeat_plus_copy(make_tuple(I1, I0));
nrepeat_plus_copy(make_tuple(I1, I1));
mrepeat_minus_copy(make_tuple(I0, I1));
}
}
else if constexpr(MRepeat == 2 && NRepeat == 1)
{
init_copy(make_tuple(I0, I0));
mrepeat_plus_copy(make_tuple(I1, I0));
}
else if constexpr(MRepeat == 1 && NRepeat == 2)
{
init_copy(make_tuple(I0, I0));
nrepeat_plus_copy(make_tuple(I0, I1));
}
else if constexpr(MRepeat == 1 && NRepeat == 1)
{
init_copy(make_tuple(I0, I0));
}
}
#endif
}
}; // namespace ck
} // namespace ck
#endif

View File

@@ -0,0 +1,229 @@
#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
#define CK_THREADWISE_CONTRACTION_DLOPS_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
// Tensor element can be vectorized data
// Assume:
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction
static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto TK = TKLengths{}[I0];
constexpr auto TM0 = TMLengths{}[I0];
constexpr auto TM1 = TMLengths{}[I1];
constexpr auto TN0 = TNLengths{}[I0];
constexpr auto TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK, 1>{}([&](auto tk) {
static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, TN1, 1>{}([&](auto tn1) {
constexpr index_t a_offset =
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk, tm0, tm1));
constexpr index_t b_offset =
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk, tn0, tn1));
constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
});
});
});
});
});
}
};
// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
// Tensor element can be vectorized data
// Assume:
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t TK0 = TKLengths{}[I0];
constexpr index_t TK1 = TKLengths{}[I1];
constexpr index_t TM0 = TMLengths{}[I0];
constexpr index_t TM1 = TMLengths{}[I1];
constexpr index_t TN0 = TNLengths{}[I0];
constexpr index_t TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t a_offset =
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
constexpr index_t b_offset =
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
inner_product<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
});
});
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,160 @@
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
index_t H,
index_t W,
typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(
is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
remove_cv_t<remove_reference_t<FloatA>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
remove_cv_t<remove_reference_t<FloatB>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
remove_cv_t<remove_reference_t<FloatC>>>::value &&
"wrong! inconsistent type");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto E = ADesc{}.GetLength(I0);
constexpr auto K = ADesc{}.GetLength(I1);
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, E, 1>{}([&](auto e) {
static_for<0, K, 1>{}([&](auto k) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));
if constexpr(H == 2 && W == 2)
{
constexpr index_t b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr index_t b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
constexpr index_t b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr index_t b_offset_3 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
constexpr index_t c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
constexpr index_t c_offset_2 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
constexpr index_t c_offset_3 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else if constexpr(H == 4 && W == 1)
{
constexpr index_t b_offset_0 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
constexpr index_t b_offset_1 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
constexpr index_t b_offset_2 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
constexpr index_t b_offset_3 =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
constexpr index_t c_offset_0 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
constexpr index_t c_offset_1 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
constexpr index_t c_offset_2 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
constexpr index_t c_offset_3 =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset_0>{}],
b_buf[Number<b_offset_1>{}],
b_buf[Number<b_offset_2>{}],
b_buf[Number<b_offset_3>{}],
c_buf(Number<c_offset_0>{}),
c_buf(Number<c_offset_1>{}),
c_buf(Number<c_offset_2>{}),
c_buf(Number<c_offset_3>{}));
}
else
{
static_for<0, H, 1>{}([&](auto h) {
static_for<0, W, 1>{}([&](auto w) {
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
constexpr index_t c_offset =
CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
#if 0
c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
#else
amd_assembly_inner_product(a_buf[Number<a_offset>{}],
b_buf[Number<b_offset>{}],
c_buf(Number<c_offset>{}));
#endif
});
});
}
});
});
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,59 @@
#ifndef CK_THREADWISE_TENSOR_SET_HPP
#define CK_THREADWISE_TENSOR_SET_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// Assume:
// 1. Desc is known at compile-time
// 2. Buffer is StaticBuffer
// 3. OriginIdx is known at compile-time
// 4. use #-step
template <typename Data,
typename Desc,
typename SliceLengths,
typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceSet_v1
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
template <typename OriginIdx, typename Buffer>
__device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
{
static_assert(Desc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
"wrong! OriginIdx need to be known at compile-time");
// Desc is known at compile-time
constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};
// OriginIdx is known at compile-time
constexpr auto origin_idx = to_multi_index(OriginIdx{});
static_ford<SliceLengths>{}([&](auto access_idx) {
constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
constexpr index_t offset = coord.GetOffset();
if constexpr(is_valid)
{
buf(Number<offset>{}) = initial_value;
}
});
}
};
} // namespace ck
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,779 @@
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace ck {
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
typename SrcVectorTensorLengths,
typename DstVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename DstVectorTensorContiguousDimOrder,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
struct ThreadwiseTensorSliceTransfer_v3r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
{
// TODO: fix this
static_assert(is_same<SrcData, DstData>::value,
"wrong! current implementation assume SrcData and DstData are same type");
static_for<0, nDim, 1>{}([](auto i) {
static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 &&
SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0,
"wrong!");
});
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
remove_cv_t<remove_reference_t<SrcData>>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// tensor descriptor for src_vector
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
container_reverse_exclusive_scan(
container_reorder_given_new2old(src_vector_tensor_lengths,
SrcVectorTensorContiguousDimOrder{}),
math::multiplies{},
I1),
SrcVectorTensorContiguousDimOrder{});
constexpr auto src_vector_desc =
make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));
// access order and lengths
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
});
return make_tensor_coordinate_step(
src_desc, forward_step_idx, src_step_hacks[I0][i]);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
});
return make_tensor_coordinate_step(
src_desc, backward_step_idx, src_step_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_vector_tensor_lengths;
}();
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
using src_vector_t = typename decltype(src_vector)::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf to src_vector
src_vector.template AsType<src_vector_t>()(I0) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
// copy data from src_vector to buffer_
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
constexpr index_t src_vector_offset =
src_vector_desc.CalculateOffset(src_vector_idx);
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx);
buffer_(Number<buffer_offset>{}) =
src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}];
});
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
template <typename DstBuffer, typename DstStepHacks>
__device__ void
RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
{
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!");
static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
remove_cv_t<remove_reference_t<DstData>>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
// tensor descriptor for dst_vector
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new(
container_reverse_exclusive_scan(
container_reorder_given_new2old(dst_vector_tensor_lengths,
DstVectorTensorContiguousDimOrder{}),
math::multiplies{},
I1),
DstVectorTensorContiguousDimOrder{});
constexpr auto dst_vector_desc =
make_naive_tensor_descriptor(sequence_to_tuple_of_number(dst_vector_tensor_lengths),
sequence_to_tuple_of_number(dst_vector_tensor_strides));
// dst access order and lengths
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
});
return make_tensor_coordinate_step(
dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
});
return make_tensor_coordinate_step(
dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_vector_tensor_lengths;
}();
vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector;
// copy data from buffer_ to dst_vector (also cast from SrcData to DstData)
static_ford<DstVectorTensorLengths>{}([&](auto dst_vector_idx_) {
constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_);
constexpr index_t buffer_offset =
buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx);
constexpr index_t dst_vector_offset =
dst_vector_desc.CalculateOffset(dst_vector_idx);
dst_vector.template AsType<DstData>()(Number<dst_vector_offset>{}) =
type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
});
using dst_vector_t = typename decltype(dst_vector)::type;
// copy data from dst_vector to dst_buf
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunRead(src_desc, src_buf, src_step_hacks);
}
template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
constexpr auto dst_step_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
RunWrite(dst_desc, dst_buf, dst_step_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_vector_tensor_lengths;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};
constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_vector_tensor_lengths;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
static constexpr auto buffer_desc_ =
make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{}));
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_, true> buffer_;
SrcCoord src_coord_;
DstCoord dst_coord_;
};
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-step
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation
// 3. vector access on src
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
typename SrcVectorTensorLengths,
typename SrcVectorTensorContiguousDimOrder,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_for<0, nDim, 1>{}([](auto i) {
static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!");
});
}
template <typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
typename SrcBuffer,
typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcRefToOriginDisplacement&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstOriginIdx&,
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
remove_cv_t<remove_reference_t<SrcData>>>::value &&
is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
remove_cv_t<remove_reference_t<DstData>>>::value,
"wrong! SrcBuffer or DstBuffer data type is wrong");
static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
static_assert(
is_known_at_compile_time<
remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time");
// SrcDesc and DstDesc are known at compile-time
constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
// SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
// tensor descriptor for src_vector
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
container_reverse_exclusive_scan(
container_reorder_given_new2old(src_vector_tensor_lengths,
SrcVectorTensorContiguousDimOrder{}),
math::multiplies{},
I1),
SrcVectorTensorContiguousDimOrder{});
constexpr auto src_vector_desc =
make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths),
sequence_to_tuple_of_number(src_vector_tensor_strides));
// access order and lengths
constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// position in slice window
constexpr auto data_to_origin_disp_idx =
ordered_access_idx.ReorderGivenOld2New(dim_access_order) *
src_vector_tensor_lengths;
// src coordinate at starting point of src_vector
constexpr auto src_ref_to_data_disp_idx =
src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
constexpr auto src_ref_to_data_disp_coord_step =
make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
auto src_data_coord = src_ref_coord_;
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
using src_vector_t = typename decltype(src_vector)::type;
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_data_coord);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(I0) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
// copy data from src_vector into dst_buf (also cast from SrcData to DstData)
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
constexpr index_t src_vector_offset =
src_vector_desc.CalculateOffset(src_vector_idx);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
dst_buf(Number<dst_offset>{}) = type_convert<DstData>{}(
src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
});
});
}
template <typename SrcSliceMoveStepIdx>
__device__ void MoveSrcSliceWindow(const SrcDesc&,
const SrcSliceMoveStepIdx& src_slice_move_step_idx)
{
constexpr auto src_desc = SrcDesc{};
const auto src_slice_move_step_iter =
make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
}
private:
SrcCoord src_ref_coord_;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,801 @@
#ifndef CK_XDLOPS_GEMM_HPP
#define CK_XDLOPS_GEMM_HPP
#include "common_header.hpp"
#include "math.hpp"
#include "amd_xdlops.hpp"
namespace ck {
enum struct mfma_instr
{
/// fp32
mfma_f32_32x32x1xf32 = 0,
mfma_f32_16x16x1xf32,
mfma_f32_4x4x1xf32,
mfma_f32_32x32x2xf32, // k reduction
mfma_f32_16x16x4xf32, // k reduction
/// fp16
mfma_f32_32x32x4f16,
mfma_f32_16x16x4f16,
mfma_f32_4x4x4f16,
mfma_f32_32x32x8f16, // k reduction
mfma_f32_16x16x16f16, // k reduction
/// bfp16
mfma_f32_32x32x2bf16,
mfma_f32_16x16x2bf16,
mfma_f32_4x4x2bf16,
mfma_f32_32x32x4bf16, // k reduction
mfma_f32_16x16x8bf16, // k reduction
};
template <mfma_instr instr>
struct mfma_info;
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 1;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 1;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
// treat 4x4x1 as a single-blk 4x64 mfma
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 1;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 1;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 8;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x8f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 16;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x16f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 4;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 4;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 4;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t COffset,
class FloatA,
class FloatB,
class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_4x4x4f16<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c);
}
};
#if 0
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 2;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 2;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 4;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 32;
static constexpr index_t n = 32;
static constexpr index_t k = 4;
static constexpr index_t cycles = 64;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 8;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = wave_size / num_threads_blk;
static constexpr index_t num_output_blks = 4;
static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks;
static constexpr index_t m = 16;
static constexpr index_t n = 16;
static constexpr index_t k = 2;
static constexpr index_t cycles = 32;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
}
};
template <>
struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_blk = 1;
static constexpr index_t num_regs_blk = group_size * num_groups_blk;
static constexpr index_t num_threads_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t num_regs_xdlops = 4;
static constexpr index_t m = 4;
static constexpr index_t n = 64;
static constexpr index_t k = 2;
static constexpr index_t cycles = 8;
static constexpr index_t k_base = 2;
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t AStride,
index_t BStride,
class FloatA,
class FloatB,
class FloatC>
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
{
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
}
};
#endif
template <mfma_instr instr, index_t MPerXdlops_, index_t NPerXdlops_>
struct xdlops_info
{
static constexpr auto mfma_type = mfma_info<instr>{};
static constexpr index_t MPerXdlops = MPerXdlops_;
static constexpr index_t NPerXdlops = NPerXdlops_;
static constexpr bool IsABroadcast()
{
static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
return true;
}
static constexpr bool IsKReduction()
{
return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1);
}
static constexpr index_t GetKPerXdlops()
{
return IsKReduction() ? mfma_type.num_input_blks : 1;
}
static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; }
};
template <class base_type, index_t MPerWave, index_t NPerWave, index_t KPack>
struct XdlopsGemm
{
template <class base_type_ = base_type,
index_t MPerWave_ = MPerWave,
index_t NPerWave_ = NPerWave>
static constexpr auto GetXdlopsInfo();
template <>
static constexpr auto GetXdlopsInfo<float, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x1xf32, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x1xf32, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x1xf32, 4, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2xf32, 32, 32>{};
}
template <>
static constexpr auto GetXdlopsInfo<float, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4xf32, 16, 16>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 64, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4f16, 32, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x8f16, 32, 32>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x16f16, 16, 16>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x4f16, 16, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 8, 64>{};
}
template <>
static constexpr auto GetXdlopsInfo<half_t, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x4f16, 4, 64>{};
}
#if 0
template <>
static constexpr auto GetXdlopsInfo<ushort, 128, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 128>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 64, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 8, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 4, 64>()
{
return xdlops_info<mfma_instr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 32, 32>()
{
return xdlops_info<mfma_instr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
}
template <>
static constexpr auto GetXdlopsInfo<ushort, 16, 16>()
{
return xdlops_info<mfma_instr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
}
#endif
using CIndex = MultiIndex<2>;
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
__host__ __device__ constexpr XdlopsGemm()
{
static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
NPerXdlops == 64,
"Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
"m != num_input_blks * num_regs_blk");
static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks ||
mfma_type.num_output_blks == 1,
"incorrect num_output_blks");
static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n,
"num_regs_blk incorrect");
static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!");
}
__device__ static constexpr index_t GetRegSizePerXdlops()
{
return MPerXdlops * NPerXdlops / mfma_type.wave_size;
}
template <class ADesc,
class BDesc,
class CDesc,
index_t m0,
index_t n0,
class FloatA,
class FloatB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
is_same<base_type, ushort>::value,
"base base_type must be float, half, ushort!");
static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base");
constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops();
static_for<0, KPack, mfma_type.k_base>{}([&](auto k) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k));
constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k));
mfma_type.template run<MPerXdlops, NPerXdlops, c_offset>(
p_a_wave[Number<a_offset / mfma_type.k_base>{}],
p_b_wave[Number<b_offset / mfma_type.k_base>{}],
p_c_thread);
});
}
__device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
{
const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
const index_t blk_id = laneId / mfma_type.num_threads_blk;
const index_t blk_td = laneId % mfma_type.num_threads_blk;
index_t n_offset = blk_i * mfma_type.n + blk_td;
index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size;
return CIndex{m_offset, n_offset};
}
static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats;
static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats;
static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops;
static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops;
static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction();
static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast();
static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops();
static constexpr auto GetBlkId(const index_t lane_id)
{
return lane_id / mfma_type.num_threads_blk;
}
static constexpr auto GetBlkTd(const index_t lane_id)
{
return lane_id % mfma_type.num_threads_blk;
}
static constexpr auto mfma_type = GetXdlopsInfo().mfma_type;
struct CLayout
{
__host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; }
__host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; }
__host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; }
__host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; }
__device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; }
__device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; }
__device__ static constexpr index_t GetNumXdlops()
{
return MPerXdlops * NPerXdlops /
(mfma_type.m * mfma_type.n * mfma_type.num_output_blks);
}
};
__host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; }
};
} // namespace ck
#endif

View File

@@ -0,0 +1,44 @@
#ifndef CK_AMD_ADDRESS_SPACE_HPP
#define CK_AMD_ADDRESS_SPACE_HPP
#include "config.hpp"
#include "c_style_pointer_cast.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck {
enum AddressSpaceEnum_t
{
Generic,
Global,
Lds,
Sgpr,
Vgpr,
};
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CONSTANT*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck
#endif

View File

@@ -0,0 +1,681 @@
#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
#define CK_AMD_BUFFER_ADDRESSING_HPP
#include "data_type.hpp"
namespace ck {
template <typename T>
union BufferResource
{
// 128 bit SGPRs to supply buffer resource in buffer instructions
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
int32x4_t content;
StaticallyIndexedArray<T*, 2> address;
StaticallyIndexedArray<int32_t, 4> range;
StaticallyIndexedArray<int32_t, 4> config;
};
template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{
BufferResource<T> wave_buffer_resource;
// wavewise base address (64 bit)
wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
// wavewise range (32 bit)
wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
// wavewise setting (32 bit)
wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;
return wave_buffer_resource.content;
}
// load
__device__ int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
__device__ int8x2_t
llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
__device__ int8x4_t
llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
__device__ int16_t
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32_t
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
__device__ int32x2_t
llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
__device__ int32x4_t
llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
// half
__device__ half_t
llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
__device__ half2_t
llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
__device__ half4_t
llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
// float
__device__ float
llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
__device__ float2_t
llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
__device__ float4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
// store
__device__ void
llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
__device__ void
llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
__device__ void
llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
__device__ void
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
__device__ void
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
__device__ void
llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
__device__ void
llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
// half
__device__ void
llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
__device__ void
llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
__device__ void
llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
// float
__device__ void
llvm_amdgcn_raw_buffer_store_fp32(float vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
__device__ void
llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
__device__ void
llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
template <typename T, index_t N>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_fp32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_fp32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp;
tmp.AsType<float4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<float4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
0);
return tmp.AsType<float8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_fp16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_fp16x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
#if 0
vector_type<half_t, 8> tmp;
tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<half4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(half_t),
0);
return tmp.AsType<half8_t>()(Number<0>{});
#else
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<half8_t>(tmp);
#endif
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 8)
{
vector_type<int32_t, 8> tmp;
tmp.AsType<int32x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int32x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
0);
return tmp.AsType<int32x8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i8(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
}
else if constexpr(N == 2)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return llvm_amdgcn_raw_buffer_load_i8x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x2_t>(tmp);
#endif
}
else if constexpr(N == 4)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
return llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
#else
int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x4_t>(tmp);
#endif
}
else if constexpr(N == 8)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 8> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int8x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
0);
return tmp.AsType<int8x8_t>()(Number<0>{});
#else
int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x8_t>(tmp);
#endif
}
else if constexpr(N == 16)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
vector_type<int8_t, 16> tmp;
tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
tmp.AsType<int8x4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int8_t),
0);
tmp.AsType<int8x4_t>()(Number<2>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int8_t),
0);
tmp.AsType<int8x4_t>()(Number<3>{}) =
llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int8_t),
0);
return tmp.AsType<int8x16_t>()(Number<0>{});
#else
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
return as_type<int8x16_t>(tmp);
#endif
}
}
}
template <typename T, index_t N>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
"wrong! not implemented");
if constexpr(is_same<T, float>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int32_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, int8_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#else
llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#endif
}
else if constexpr(N == 4)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#else
llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
#endif
}
else if constexpr(N == 8)
{
llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 16)
{
llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
}
else if constexpr(is_same<T, half_t>::value)
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 8)
{
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
0);
}
}
}
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff;
return amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
// buffer_load requires:
// 1) p_src_wave must be in global memory space
// 2) p_src_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ typename vector_type_maker<T, N>::type::type
amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_element_offset,
bool src_thread_element_valid,
index_t src_element_space_size,
T customized_value)
{
const int32x4_t src_wave_buffer_resource =
make_wave_buffer_resource(p_src_wave, src_element_space_size);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(customized_value);
}
// buffer_store requires:
// 1) p_dst_wave must be global memory
// 2) p_dst_wave to be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size)
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size);
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
using vector_t = typename vector_type_maker<T, N>::type::type;
using scalar_t = typename scalar_type<vector_t>::type;
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff;
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
if(dst_thread_element_valid)
{
amd_buffer_store_impl<scalar_t, vector_size>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
#endif
}
} // namespace ck
#endif

View File

@@ -0,0 +1,356 @@
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace ck {
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
asm volatile("\n \
v_fmac_f32 %0, %2, %3 \n \
v_fmac_f32 %1, %2, %4 \n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
asm volatile("\n \
v_fmac_f32 %0, %4, %5 \n \
v_fmac_f32 %1, %4, %6 \n \
v_fmac_f32 %2, %4, %7 \n \
v_fmac_f32 %3, %4, %8 \n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
{
asm volatile("\n \
v_dot2_f32_f16 %0, %2, %3, %0\n \
v_dot2_f32_f16 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
{
// TODO remove pointer casting
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
// do dot2 two times
asm volatile("\n \
v_dot2_f32_f16 %0, %2, %4, %0\n \
v_dot2_f32_f16 %1, %2, %6, %1\n \
v_dot2_f32_f16 %0, %3, %5, %0\n \
v_dot2_f32_f16 %1, %3, %7, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(p_a_half2[0]),
"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
"v"(p_b1_half2[1]),
"0"(c0),
"1"(c1));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half2_t a,
half2_t b0,
half2_t b1,
half2_t b2,
half2_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
asm volatile("\n \
v_dot2_f32_f16 %0, %4, %5, %0\n \
v_dot2_f32_f16 %1, %4, %6, %1\n \
v_dot2_f32_f16 %2, %4, %7, %2\n \
v_dot2_f32_f16 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(half4_t a,
half4_t b0,
half4_t b1,
half4_t b2,
half4_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
// TODO remove pointer casting
const half2_t* p_a_half2 = c_style_pointer_cast<const half2_t*>(&a);
const half2_t* p_b0_half2 = c_style_pointer_cast<const half2_t*>(&b0);
const half2_t* p_b1_half2 = c_style_pointer_cast<const half2_t*>(&b1);
const half2_t* p_b2_half2 = c_style_pointer_cast<const half2_t*>(&b2);
const half2_t* p_b3_half2 = c_style_pointer_cast<const half2_t*>(&b3);
// do dot2 two times
asm volatile("\n \
v_dot2_f32_f16 %0, %4, %6, %0\n \
v_dot2_f32_f16 %1, %4, %8, %1\n \
v_dot2_f32_f16 %2, %4, %10, %2\n \
v_dot2_f32_f16 %3, %4, %12, %3\n \
v_dot2_f32_f16 %0, %5, %7, %0\n \
v_dot2_f32_f16 %1, %5, %9, %1\n \
v_dot2_f32_f16 %2, %5, %11, %2\n \
v_dot2_f32_f16 %3, %5, %13, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(p_a_half2[0]),
"v"(p_a_half2[1]),
"v"(p_b0_half2[0]),
"v"(p_b0_half2[1]),
"v"(p_b1_half2[0]),
"v"(p_b1_half2[1]),
"v"(p_b2_half2[0]),
"v"(p_b2_half2[1]),
"v"(p_b3_half2[0]),
"v"(p_b3_half2[1]),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
}
__device__ void amd_assembly_outer_product_1x4(half8_t a,
half8_t b0,
half8_t b1,
half8_t b2,
half8_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
// TODO remove pointer casting
const half4_t* p_a_half4 = c_style_pointer_cast<const half4_t*>(&a);
const half4_t* p_b0_half4 = c_style_pointer_cast<const half4_t*>(&b0);
const half4_t* p_b1_half4 = c_style_pointer_cast<const half4_t*>(&b1);
const half4_t* p_b2_half4 = c_style_pointer_cast<const half4_t*>(&b2);
const half4_t* p_b3_half4 = c_style_pointer_cast<const half4_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
amd_assembly_outer_product_1x4(
p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
}
__device__ void amd_assembly_outer_product_1x4(half16_t a,
half16_t b0,
half16_t b1,
half16_t b2,
half16_t b3,
float& c0,
float& c1,
float& c2,
float& c3)
{
// TODO remove pointer casting
const half8_t* p_a_half8 = c_style_pointer_cast<const half8_t*>(&a);
const half8_t* p_b0_half8 = c_style_pointer_cast<const half8_t*>(&b0);
const half8_t* p_b1_half8 = c_style_pointer_cast<const half8_t*>(&b1);
const half8_t* p_b2_half8 = c_style_pointer_cast<const half8_t*>(&b2);
const half8_t* p_b3_half8 = c_style_pointer_cast<const half8_t*>(&b3);
amd_assembly_outer_product_1x4(
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
amd_assembly_outer_product_1x4(
p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void
amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %2, %3, %0\n \
v_dot4_i32_i8 %1, %2, %4, %1\n \
"
: "=v"(c0), "=v"(c1)
: "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"0"(c0),
"1"(c1));
#else
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
#endif
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c2 += inner_product(a, b2)
// c3 += inner_product(a, b3)
__device__ void amd_assembly_outer_product_1x4(int8x4_t a,
int8x4_t b0,
int8x4_t b1,
int8x4_t b2,
int8x4_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
#if 1
asm volatile("\n \
v_dot4_i32_i8 %0, %4, %5, %0\n \
v_dot4_i32_i8 %1, %4, %6, %1\n \
v_dot4_i32_i8 %2, %4, %7, %2\n \
v_dot4_i32_i8 %3, %4, %8, %3\n \
"
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
: "v"(as_type<int32_t>(a)),
"v"(as_type<int32_t>(b0)),
"v"(as_type<int32_t>(b1)),
"v"(as_type<int32_t>(b2)),
"v"(as_type<int32_t>(b3)),
"0"(c0),
"1"(c1),
"2"(c2),
"3"(c3));
#else
c0 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b0), c0, false);
c1 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b1), c1, false);
c2 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b2), c2, false);
c3 = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b3), c3, false);
#endif
}
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
int8x8_t b0,
int8x8_t b1,
int8x8_t b2,
int8x8_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
c3);
}
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
int8x16_t b0,
int8x16_t b1,
int8x16_t b2,
int8x16_t b3,
int32_t& c0,
int32_t& c1,
int32_t& c2,
int32_t& c3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I0],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I1],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I2],
c0,
c1,
c2,
c3);
amd_assembly_outer_product_1x4(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b0}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b1}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b2}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b3}.AsType<int8x4_t>()[I3],
c0,
c1,
c2,
c3);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,11 @@
#ifndef CK_AMD_LLVM_INTRINSIC_HPP
#define CK_AMD_LLVM_INTRINSIC_HPP
#include "data_type.hpp"
namespace ck {
__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
} // namespace ck
#endif

View File

@@ -0,0 +1,499 @@
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
namespace ck {
// A, B, C, cbsz, abid, blgp
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x1f32;
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x2f32;
template <index_t COffset>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x1f32;
template <index_t COffset>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x1f32;
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x4f16;
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
1,
1,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
1,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x8f16;
template <index_t COffset>
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x16f16;
template <index_t COffset>
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f16;
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
2,
0,
0);
}
};
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x4f16;
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
}
};
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
{
template <class FloatC>
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
{
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
4,
0,
0);
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
reg_a,
reg_b,
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
4,
1,
0);
}
};
#if 0
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;
template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
#endif
} // namespace ck
#endif

View File

@@ -0,0 +1,63 @@
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
#include "functional2.hpp"
#include "sequence.hpp"
namespace ck {
template <typename TData, index_t NSize>
struct Array
{
using type = Array;
using data_type = TData;
TData mData[NSize];
__host__ __device__ static constexpr index_t Size() { return NSize; }
__host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; }
__host__ __device__ constexpr TData& At(index_t i) { return mData[i]; }
__host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); }
__host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
static_assert(T::Size() == Size(), "wrong! size not the same");
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
};
// empty Array
template <typename TData>
struct Array<TData, 0>
{
using type = Array;
using data_type = TData;
__host__ __device__ static constexpr index_t Size() { return 0; }
};
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
using data_type = remove_cv_t<remove_reference_t<X>>;
return Array<data_type, sizeof...(Xs) + 1>{{std::forward<X>(x), std::forward<Xs>(xs)...}};
}
// make empty array
template <typename X>
__host__ __device__ constexpr auto make_array()
{
return Array<X, 0>{};
}
} // namespace ck
#endif

View File

@@ -0,0 +1,77 @@
#ifndef CK_ARRAY_MULTI_INDEX_HPP
#define CK_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = Array<index_t, N>;
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
{
return make_array<index_t>(index_t{xs}...);
}
template <index_t NSize>
__host__ __device__ constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
__host__ __device__ constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
template <index_t NSize, typename X>
__host__ __device__ constexpr auto operator+=(MultiIndex<NSize>& y, const X& x)
{
static_assert(X::Size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
return y;
}
template <index_t NSize, typename X>
__host__ __device__ constexpr auto operator-=(MultiIndex<NSize>& y, const X& x)
{
static_assert(X::Size() == NSize, "wrong! size not the same");
static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
return y;
}
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator+(const MultiIndex<NSize>& a, const T& b)
{
using type = MultiIndex<NSize>;
static_assert(T::Size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; });
return r;
}
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator-(const MultiIndex<NSize>& a, const T& b)
{
using type = MultiIndex<NSize>;
static_assert(T::Size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; });
return r;
}
template <index_t NSize, typename T>
__host__ __device__ constexpr auto operator*(const MultiIndex<NSize>& a, const T& b)
{
using type = MultiIndex<NSize>;
static_assert(T::Size() == NSize, "wrong! size not the same");
type r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; });
return r;
}
} // namespace ck
#endif

View File

@@ -0,0 +1,22 @@
#ifndef CK_C_STYLE_POINTER_CAST_HPP
#define CK_C_STYLE_POINTER_CAST_HPP
#include "type.hpp"
#include "enable_if.hpp"
namespace ck {
template <typename PY,
typename PX,
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}
} // namespace ck
#endif

View File

@@ -0,0 +1,46 @@
#ifndef CK_COMMON_HEADER_HPP
#define CK_COMMON_HEADER_HPP
#include "config.hpp"
#include "array.hpp"
#include "container_helper.hpp"
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
#include "multi_index.hpp"
#include "data_type.hpp"
#include "data_type_enum.hpp"
#include "data_type_enum_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
#include "functional4.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp"
#include "math.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "synchronization.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "type.hpp"
#include "magic_division.hpp"
#include "utility.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_address_space.hpp"
#include "amd_buffer_addressing.hpp"
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"
#include "inner_product.hpp"
// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#endif
#endif

View File

@@ -0,0 +1,134 @@
#ifndef CK_CONFIG_AMD_HPP
#define CK_CONFIG_AMD_HPP
#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#endif
#include "bfloat16_dev.hpp"
// "Constant" address space for kernel parameter
#define CONSTANT __attribute__((address_space(4)))
// GPU target
// should enable one and only one GPU target
#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
#error Need to define (only) one GPU target
#endif
// launch bounds
#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 2
#endif
// buffer resourse
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(CK_AMD_GPU_GFX1030)
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#endif
// FMA instruction
#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900)
#define CK_USE_AMD_V_MAC_F32
#elif defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90a) || \
defined(CK_AMD_GPU_GFX1030)
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#endif
// multi index
#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0
// AMD inline asm
#ifndef CK_USE_AMD_INLINE_ASM
#define CK_USE_AMD_INLINE_ASM 1
#endif
// AMD inner product (DLOP)
#ifndef CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
#endif
// AMD buffer addressing
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
#define CK_USE_AMD_BUFFER_ADDRESSING 1
#endif
// only gfx908 support native floating point atomic add
#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD
#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0
#endif
// AMD XDLOPS
#ifndef CK_USE_AMD_XDLOPS
#define CK_USE_AMD_XDLOPS 0
#endif
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#endif
// experimental implementation
#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
#endif
#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK
#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1
#endif
// pass tensor descriptor by value or void*
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
// merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0
// hack: have underlying assumption that need to be satsified, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be
// thread-invariant, otherwise it's a bug
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
#ifndef CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
#endif
// workaround for compiler crash when compiling recursive lambda
#ifndef CK_WORKAROUND_SWDEV_275126
#define CK_WORKAROUND_SWDEV_275126 1
#endif
// workaround for compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1
#endif
// workaround for compiler crash when using buffer load/store for i8
#ifndef CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif
namespace ck {
enum InMemoryDataOperationEnum_t
{
Set,
AtomicAdd
};
// index type
using index_t = int32_t;
} // namespace ck
#endif

View File

@@ -0,0 +1,155 @@
#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP
#define CK_CONTAINER_ELEMENT_PICKER_HPP
#include "functional2.hpp"
#include "sequence.hpp"
namespace ck {
// Arr: Array or StaticallyIndexedArray
// Picks: Sequence<...>
template <typename Arr, typename Picks>
struct ContainerElementPicker
{
using type = ContainerElementPicker;
#if 0
using data_type = typename Arr::data_type;
#endif
__host__ __device__ constexpr ContainerElementPicker() = delete;
__host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array}
{
constexpr index_t imax =
reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I> i) const
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[i];
return mArray[IP];
}
template <index_t I>
__host__ __device__ constexpr auto& At(Number<I> i)
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[i];
return mArray(IP);
}
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
{
return At(i);
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I> i)
{
return At(i);
}
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
static_assert(T::Size() == Size(), "wrong! size not the same");
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
private:
Arr& mArray;
};
// Arr: Array or StaticallyIndexedArray
// Picks: Sequence<...>
template <typename Arr, typename Picks>
struct ConstantContainerElementPicker
{
using type = ConstantContainerElementPicker;
#if 0
using data_type = typename Arr::data_type;
#endif
__host__ __device__ constexpr ConstantContainerElementPicker() = delete;
__host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array}
{
constexpr index_t imax =
reduce_on_sequence(Picks{}, math::maximize<index_t>{}, Number<0>{});
static_assert(imax < Arr::Size(), "wrong! exceeding # array element");
}
__host__ __device__ static constexpr auto Size() { return Picks::Size(); }
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I> i) const
{
static_assert(I < Size(), "wrong!");
constexpr auto IP = Picks{}[i];
return mArray[IP];
}
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
{
return At(i);
}
private:
const Arr& mArray;
};
template <typename Arr, typename Picks, typename X>
__host__ __device__ constexpr auto operator+=(ContainerElementPicker<Arr, Picks>& y, const X& x)
{
using Y = ContainerElementPicker<Arr, Picks>;
constexpr index_t nsize = Y::Size();
static_assert(nsize == X::Size(), "wrong! size not the same");
static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; });
return y;
}
template <typename Arr, typename Picks, typename X>
__host__ __device__ constexpr auto operator-=(ContainerElementPicker<Arr, Picks>& y, const X& x)
{
using Y = ContainerElementPicker<Arr, Picks>;
constexpr index_t nsize = Y::Size();
static_assert(nsize == X::Size(), "wrong! size not the same");
static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; });
return y;
}
template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks)
{
return ContainerElementPicker<Arr, Picks>(a);
}
template <typename Arr, typename Picks>
__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks)
{
return ConstantContainerElementPicker<Arr, Picks>(a);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,403 @@
#ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "tuple.hpp"
#include "tuple_helper.hpp"
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
namespace ck {
template <typename TData, index_t NSize>
__host__ __device__ constexpr auto container_push_back(const Array<TData, NSize>& a, const TData& x)
{
Array<TData, NSize + 1> r;
static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; });
r(Number<NSize>{}) = x;
return r;
}
template <typename... Ts, typename T>
__host__ __device__ constexpr auto container_push_front(const Tuple<Ts...>& a, const T& x)
{
return container_concat(make_tuple(x), a);
}
template <typename... Ts, typename T>
__host__ __device__ constexpr auto container_push_back(const Tuple<Ts...>& a, const T& x)
{
return container_concat(a, make_tuple(x));
}
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_new2old(const Array<TData, NSize>& old_array, Sequence<IRs...> /*new2old*/)
{
static_assert(NSize == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_array(old_array[Number<IRs>{}]...);
}
template <typename TData, index_t NSize, index_t... IRs>
__host__ __device__ constexpr auto
container_reorder_given_old2new(const Array<TData, NSize>& old_array, Sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_array, typename sequence_map_inverse<decltype(old2new)>::type{});
}
template <typename... Ts, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple<Ts...>& old_tuple,
Sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
return make_tuple(old_tuple[Number<IRs>{}]...);
}
template <typename... Ts, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple<Ts...>& old_tuple,
Sequence<IRs...> old2new)
{
return container_reorder_given_new2old(
old_tuple, typename sequence_map_inverse<decltype(old2new)>::type{});
}
template <index_t... Is, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence<Is...> /* old_seq */,
Sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
return Sequence<Sequence<Is...>::At(Number<IRs>{})...>{};
}
template <index_t... Is, index_t... IRs>
__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is...> old_seq,
Sequence<IRs...> /* old2new */)
{
static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent");
static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");
constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};
return container_reorder_given_new2old(old_seq, new2old);
}
#if !CK_WORKAROUND_SWDEV_275126
// rocm-4.1 compiler would crash for recursive lambda
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::Size(),
index_t IStep = 1>
__host__ __device__ constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
Number<IBegin> = Number<0>{},
Number<IEnd> = Number<Container::Size()>{},
Number<IStep> = Number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto r_old) {
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
// recursively call f/fs
return fs(fs, i + Number<IStep>{}, r_new);
}
else
{
return r_new;
}
};
// start recursion
return f(f, Number<IBegin>{}, init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename Container,
typename Reduce,
typename ROld,
index_t I,
index_t IEnd,
index_t IStep>
__host__ __device__ constexpr auto container_reduce_impl(
const Container& x, Reduce reduce, ROld r_old, Number<I> i, Number<IEnd>, Number<IStep>)
{
auto r_new = reduce(x[i], r_old);
if constexpr(i.value < IEnd - IStep)
{
return container_reduce_impl(
x, reduce, r_new, i + Number<IStep>{}, Number<IEnd>{}, Number<IStep>{});
}
else
{
return r_new;
}
}
// rocm-4.1 compiler would crash for recursive lambda
// container reduce with initial value
template <typename Container,
typename Reduce,
typename Init,
index_t IBegin = 0,
index_t IEnd = Container::Size(),
index_t IStep = 1>
__host__ __device__ constexpr auto container_reduce(const Container& x,
Reduce reduce,
Init init,
Number<IBegin> = Number<0>{},
Number<IEnd> = Number<Container::Size()>{},
Number<IStep> = Number<1>{})
{
static_assert((IEnd - IBegin) % IStep == 0, "wrong!");
if constexpr(IEnd > IBegin)
{
return container_reduce_impl(
x, reduce, init, Number<IBegin>{}, Number<IEnd>{}, Number<IStep>{});
}
else
{
return init;
}
}
#endif
template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr auto
container_reverse_inclusive_scan(const Array<TData, NSize>& x, Reduce f, TData init)
{
Array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[Number<0>{}]);
y(Number<0>{}) = r;
return y;
}
template <typename TData, index_t NSize, typename Reduce>
__host__ __device__ constexpr auto
container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData init)
{
Array<TData, NSize> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
y(i) = r;
r = f(r, x[i]);
});
y(Number<0>{}) = r;
return y;
}
template <index_t... Is, typename Reduce, index_t Init>
__host__ __device__ constexpr auto
container_reverse_exclusive_scan(const Sequence<Is...>& seq, Reduce f, Number<Init>)
{
return reverse_exclusive_scan_sequence(seq, f, Number<Init>{});
}
#if !CK_WORKAROUND_SWDEV_275126
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
__host__ __device__ constexpr auto
container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
// f is recursive function, fs is a dummy of f
// i is index, y_old is current scan, r_old is current reduction
auto f = [&](auto fs, auto i, auto y_old, auto r_old) {
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return fs(fs, i - Number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
};
// start recursion
return f(f, Number<NSize - 1>{}, make_tuple(init), init);
}
#else
// i is index, y_old is current scan, r_old is current reduction
template <typename... Xs, typename Reduce, index_t I, typename YOld, typename ROld>
__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl(
const Tuple<Xs...>& x, Reduce reduce, Number<I> i, YOld y_old, ROld r_old)
{
auto r_new = reduce(x[i], r_old);
auto y_new = container_push_front(y_old, r_new);
if constexpr(i.value > 1)
{
// recursively call f/fs
return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new);
}
else
{
return y_new;
}
}
template <typename... Xs, typename Reduce, typename Init>
__host__ __device__ constexpr auto
container_reverse_exclusive_scan(const Tuple<Xs...>& x, Reduce reduce, Init init)
{
constexpr index_t NSize = sizeof...(Xs);
return container_reverse_exclusive_scan_impl(
x, reduce, Number<NSize - 1>{}, make_tuple(init), init);
}
#endif
// TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<>
template <typename... Xs, typename Reduce, typename TData>
__host__ __device__ constexpr auto
container_reverse_inclusive_scan(const Tuple<Xs...>& x, Reduce f, TData init)
{
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> y;
TData r = init;
static_for<NSize - 1, 0, -1>{}([&](auto i) {
r = f(r, x[i]);
y(i) = r;
});
r = f(r, x[Number<0>{}]);
y(Number<0>{}) = r;
return y;
}
template <typename X, typename... Ys>
__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys)
{
return container_concat(x, container_concat(ys...));
}
template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
__host__ __device__ constexpr auto container_concat(const Container& x)
{
return x;
}
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const Array<T, N>& arr, Sequence<Is...>)
{
static_assert(N >= sizeof...(Is), "wrong! size");
return make_array(arr[Number<Is>{}]...);
}
template <typename... Ts, index_t... Is>
__host__ __device__ constexpr auto get_container_subset(const Tuple<Ts...>& tup, Sequence<Is...>)
{
static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size");
return make_tuple(tup[Number<Is>{}]...);
}
template <typename T, index_t N, index_t... Is>
__host__ __device__ constexpr void
set_container_subset(Array<T, N>& y, Sequence<Is...> picks, const Array<T, sizeof...(Is)>& x)
{
static_assert(N >= sizeof...(Is), "wrong! size");
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
template <typename... Ys, index_t... Is, typename... Xs>
__host__ __device__ constexpr void
set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>& x)
{
static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size");
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
template <typename Container>
__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
{
static_assert(is_known_at_compile_time<Container>::value, "wrong!");
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Container::At(i);
return Number<tmp>{};
},
Container::Size());
}
template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{
using Seq = Sequence<Is...>;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Seq::At(i);
return Number<tmp>{};
},
Seq::Size());
}
} // namespace ck
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,19 @@
#ifndef CK_DATA_TYPE_ENUM_HPP
#define CK_DATA_TYPE_ENUM_HPP
namespace ck {
enum DataTypeEnum_t
{
Half = 0,
Float = 1,
Int32 = 2,
Int8 = 3,
Int8x4 = 4,
BFloat16 = 5,
Double = 6,
Unknown = 100,
};
} // namespace ck
#endif

View File

@@ -0,0 +1,76 @@
#ifndef CK_DATA_TYPE_ENUM_HELPER_HPP
#define CK_DATA_TYPE_ENUM_HELPER_HPP
#include "data_type.hpp"
#include "data_type_enum.hpp"
namespace ck {
template <DataTypeEnum_t DataTypeEnum>
struct get_datatype_from_enum;
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Int8>
{
using type = int8_t;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Int32>
{
using type = int32_t;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Half>
{
using type = half_t;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Float>
{
using type = float;
};
template <>
struct get_datatype_from_enum<DataTypeEnum_t::Double>
{
using type = double;
};
template <typename T>
struct get_datatype_enum_from_type;
template <>
struct get_datatype_enum_from_type<int8_t>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8;
};
template <>
struct get_datatype_enum_from_type<int32_t>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32;
};
template <>
struct get_datatype_enum_from_type<half_t>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half;
};
template <>
struct get_datatype_enum_from_type<float>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float;
};
template <>
struct get_datatype_enum_from_type<double>
{
static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double;
};
} // namespace ck
#endif

View File

@@ -0,0 +1,246 @@
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP
#include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"
namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
typename ElementSpaceSize,
bool InvalidElementUseNumericalZeroValue>
struct DynamicBuffer
{
using type = T;
T* p_data_;
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
}
__host__ __device__ constexpr DynamicBuffer(T* p_data,
ElementSpaceSize element_space_size,
T invalid_element_value)
: p_data_{p_data},
element_space_size_{element_space_size},
invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{
return BufferAddressSpace;
}
template <typename X,
typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector =
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
#if CK_USE_AMD_BUFFER_ADDRESSING
bool constexpr use_amd_buffer_addressing = true;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global && use_amd_buffer_addressing)
{
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
if constexpr(InvalidElementUseNumericalZeroValue)
{
return amd_buffer_load_invalid_element_return_return_zero<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(p_data_, i, is_valid_element, element_space_size_);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<
remove_cv_t<remove_reference_t<T>>,
t_per_x>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
}
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i]) : X{0};
}
else
{
return is_valid_element ? *c_style_pointer_cast<const X*>(&p_data_[i])
: X{invalid_element_value_};
}
}
}
template <typename X,
typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
constexpr index_t scalar_per_t_vector =
scalar_type<remove_cv_t<remove_reference_t<T>>>::vector_size;
constexpr index_t scalar_per_x_vector =
scalar_type<remove_cv_t<remove_reference_t<X>>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X need to be multiple T");
if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cv_t<remove_reference_t<T>>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
#else
if(is_valid_element)
{
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
}
#endif
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
{
if(is_valid_element)
{
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
#else
// HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
// inefficient
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
int8_t>::value)
{
static_assert(
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
"wrong! not implemented for this combination, please add "
"implementation");
if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int8_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int8_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int16_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int16_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x4_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x8_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x2_t*>(&x);
}
else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
int8x16_t>::value &&
is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
{
// HACK: cast pointer of x is bad
// TODO: remove this after compiler fix
*c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
*c_style_pointer_cast<const int32x4_t*>(&x);
}
}
else
{
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
}
#endif
}
}
else
{
if(is_valid_element)
{
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
}
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
p, element_space_size, invalid_element_value};
}
} // namespace ck
#endif

View File

@@ -0,0 +1,13 @@
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP
namespace ck {
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#endif

View File

@@ -0,0 +1,116 @@
#ifndef CK_FUNCTIONAL_HPP
#define CK_FUNCTIONAL_HPP
#include "integral_constant.hpp"
#include "type.hpp"
namespace ck {
// TODO: right? wrong?
struct forwarder
{
template <typename T>
__host__ __device__ constexpr T&& operator()(T&& x) const
{
return static_cast<T&&>(x);
}
};
struct swallow
{
template <typename... Ts>
__host__ __device__ constexpr swallow(Ts&&...)
{
}
};
template <typename T>
struct logical_and
{
constexpr bool operator()(const T& x, const T& y) const { return x && y; }
};
template <typename T>
struct logical_or
{
constexpr bool operator()(const T& x, const T& y) const { return x || y; }
};
template <typename T>
struct logical_not
{
constexpr bool operator()(const T& x) const { return !x; }
};
// Emulate if constexpr
template <bool>
struct static_if;
template <>
struct static_if<true>
{
using Type = static_if<true>;
template <typename F>
__host__ __device__ constexpr auto operator()(F f) const
{
// This is a trick for compiler:
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will
// use it,
// this will make "f" a generic lambda, so that "f" won't be compiled
// until being
// instantiated here
f(forwarder{});
return Type{};
}
template <typename F>
__host__ __device__ static void Else(F)
{
}
};
template <>
struct static_if<false>
{
using Type = static_if<false>;
template <typename F>
__host__ __device__ constexpr auto operator()(F) const
{
return Type{};
}
template <typename F>
__host__ __device__ static void Else(F f)
{
// This is a trick for compiler:
// Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will
// use it,
// this will make "f" a generic lambda, so that "f" won't be compiled
// until being
// instantiated here
f(forwarder{});
}
};
template <bool predicate, class X, class Y>
struct conditional;
template <class X, class Y>
struct conditional<true, X, Y>
{
using type = X;
};
template <class X, class Y>
struct conditional<false, X, Y>
{
using type = Y;
};
template <bool predicate, class X, class Y>
using conditional_t = typename conditional<predicate, X, Y>::type;
} // namespace ck
#endif

View File

@@ -0,0 +1,48 @@
#ifndef CK_FUNCTIONAL2_HPP
#define CK_FUNCTIONAL2_HPP
#include "functional.hpp"
#include "sequence.hpp"
namespace ck {
namespace detail {
template <class>
struct static_for_impl;
template <index_t... Is>
struct static_for_impl<Sequence<Is...>>
{
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
swallow{(f(Number<Is>{}), 0)...};
}
};
} // namespace detail
// F signature: F(Number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
__host__ __device__ constexpr static_for()
{
static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
"wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
"NBegin >= NEnd)");
}
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
f);
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,142 @@
#ifndef CK_FUNCTIONAL3_HPP
#define CK_FUNCTIONAL3_HPP
#include "functional.hpp"
#include "functional2.hpp"
#include "sequence.hpp"
#include "multi_index.hpp"
namespace ck {
namespace detail {
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
__host__ __device__ constexpr static_ford_impl()
{
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
}
// F signature: F(Sequence<...>)
// CurrentOrderedId: Sequence<...>
template <class F, class CurrentOrderedId>
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const
{
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
static_ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
f, CurrentOrderedId::PushBack(I));
});
}
};
template <class Orders>
struct static_ford_impl<Sequence<>, Orders>
{
// F signature: F(Sequence<...>)
// OrderedId: Sequence<...>
template <class F, class OrderedId>
__host__ __device__ constexpr void operator()(F f, OrderedId) const
{
// retrive unordered Id
f(OrderedId::ReorderGivenOld2New(Orders{}));
}
};
// RemainLengths: Sequence<...>
// Orders: Sequence<...>
template <class RemainLengths, class Orders>
struct ford_impl
{
__host__ __device__ constexpr ford_impl()
{
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
}
// F signature: F(Array<...> multi_id)
// CurrentOrderdId: Array<...>
template <class F, class CurrentOrderedId>
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
{
for(index_t i = 0; i < RemainLengths::Front(); ++i)
{
ford_impl<decltype(RemainLengths::PopFront()), Orders>{}(
f, container_push_back(current_ordered_id, i));
}
}
};
template <class Orders>
struct ford_impl<Sequence<>, Orders>
{
// F signature: F(Array<...> multi_id)
// CurrentOrderdId: Array<...>
template <class F, class CurrentOrderedId>
__host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const
{
// retrive unordered Id
f(container_reorder_given_old2new(current_ordered_id, Orders{}));
}
};
} // namespace detail
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which static_ford
// will loop over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct static_ford
{
__host__ __device__ constexpr static_ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
// F signature: F(Sequence<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, Sequence<>{});
}
};
// Lengths is Sequence<...>, it is the length of each dimension for
// N-dimensional loop
// Orders is Sequence<...>, it is the order of dimension in which ford will loop
// over each
// dimension
template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::GetSize(), 1>::type>
struct ford
{
__host__ __device__ constexpr ford()
{
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
}
// F signature: F(Array<...> multi_id)
// multi_id is the unordered multi-index
template <class F>
__host__ __device__ constexpr void operator()(F f) const
{
constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{});
for(index_t i = 0; i < ordered_lengths.Front(); ++i)
{
detail::ford_impl<decltype(ordered_lengths.PopFront()), Orders>{}(f,
make_multi_index(i));
}
}
};
} // namespace ck
#endif

View File

@@ -0,0 +1,62 @@
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
#include "sequence.hpp"
#include "tuple.hpp"
#include "array.hpp"
namespace ck {
namespace detail {
template <typename Indices>
struct unpack_impl;
template <index_t... Is>
struct unpack_impl<Sequence<Is...>>
{
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
}
};
template <typename Seq0, typename Seq1>
struct unpack2_impl;
// TODO: remove this, after properly implementing unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
{
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...);
}
};
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
using X_ = remove_reference_t<X>;
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
} // namespace ck
#endif

View File

@@ -0,0 +1,207 @@
#ifndef CK_INNER_PRODUCT_HPP
#define CK_INNER_PRODUCT_HPP
#include "data_type.hpp"
namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c += a * b;
#endif
}
template <>
__device__ void
inner_product<float2_t, float2_t, float>(const float2_t& a, const float2_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<float, 2>{a}.AsType<float>()[I0],
vector_type<float, 2>{b}.AsType<float>()[I0],
c);
inner_product(vector_type<float, 2>{a}.AsType<float>()[I1],
vector_type<float, 2>{b}.AsType<float>()[I1],
c);
}
template <>
__device__ void
inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
inner_product(vector_type<float, 4>{a}.AsType<float>()[I0],
vector_type<float, 4>{b}.AsType<float>()[I0],
c);
inner_product(vector_type<float, 4>{a}.AsType<float>()[I1],
vector_type<float, 4>{b}.AsType<float>()[I1],
c);
inner_product(vector_type<float, 4>{a}.AsType<float>()[I2],
vector_type<float, 4>{b}.AsType<float>()[I2],
c);
inner_product(vector_type<float, 4>{a}.AsType<float>()[I3],
vector_type<float, 4>{b}.AsType<float>()[I3],
c);
}
template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c = __builtin_amdgcn_sdot2(a, b, c, false);
#endif
#else
const auto convert = type_convert<int32_t>{};
const vector_type<half_t, 2> a_vector{a};
const vector_type<half_t, 2> b_vector{b};
static_for<0, 2, 1>{}([&](auto i) {
c += convert(a_vector.AsType<half_t>()[i]) * convert(b_vector.AsType<half_t>()[i]);
});
#endif
}
template <>
__device__ void inner_product<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
c);
inner_product(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
c);
}
template <>
__device__ void inner_product<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
c);
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
c);
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
c);
inner_product(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
c);
}
template <>
__device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
"
: "=v"(c)
: "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
#else
const auto convert = type_convert<int32_t>{};
const vector_type<int8_t, 4> a_vector{a};
const vector_type<int8_t, 4> b_vector{b};
static_for<0, 4, 1>{}([&](auto i) {
c += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
});
#endif
}
template <>
__device__ void
inner_product<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
c);
inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
c);
}
template <>
__device__ void
inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
c);
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
c);
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
c);
inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
c);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,17 @@
#ifndef CK_INTEGRAL_CONSTANT_HPP
#define CK_INTEGRAL_CONSTANT_HPP
namespace ck {
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
typedef T value_type;
typedef integral_constant type;
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
} // namespace ck
#endif

View File

@@ -0,0 +1,155 @@
#ifndef CK_MAGIC_DIVISION_HPP
#define CK_MAGIC_DIVISION_HPP
#include "config.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
namespace ck {
// magic number division
// Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
// division implementation for uint32_t is then used. Therefore, dividend value need to be
// non-negative.
// TODO:
// 1. Implement magic number divison for int32_t
// 2. Implement magic number divison for unit32_t with 32-bit value range
struct MagicDivision
{
// uint32_t
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
{
// assert(divisior >= 1 && divisior <= INT32_MAX);
uint32_t shift = 0;
for(shift = 0; shift < 32; ++shift)
{
if((1U << shift) >= divisor)
{
break;
}
}
uint64_t one = 1;
uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
// assert(multiplier <= 0xffffffffUL);
return make_tuple(uint32_t(multiplier), shift);
}
__host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
{
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<0>{}];
}
__host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor)
{
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<1>{}];
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
{
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[Number<0>{}];
constexpr uint32_t shift = tmp[Number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
{
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
return integral_constant<uint32_t, multiplier>{};
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<uint32_t, Divisor>)
{
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
return integral_constant<uint32_t, shift>{};
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
{
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
{
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<int32_t, Divisor>)
{
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
}
// magic division for uint32_t
__host__ __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
return (tmp + dividend) >> shift;
}
#if 1 // debug
// HACK: magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
__host__ __device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
uint32_t tmp =
(static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
return (tmp + dividend_u32) >> shift;
}
#else
// the inline ASM is producing wrong result
__host__ __device__ static int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t r;
asm volatile("\n \
v_mul_hi_u32 %0, %1, %2 \n \
v_add_u32_e32 %0, %1, %0 \n \
v_lshrrev_b32_e32 %0, %3, %0 \n \
"
: "=v"(r)
: "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
return as_type<int32_t>(r);
}
#endif
};
} // namespace ck
#endif

View File

@@ -0,0 +1,216 @@
#ifndef CK_MATH_HPP
#define CK_MATH_HPP
#include "config.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "enable_if.hpp"
namespace ck {
namespace math {
template <typename T, T s>
struct scales
{
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
};
template <typename T>
struct plus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct minus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
};
struct multiplies
{
template <typename A, typename B>
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const
{
return a * b;
}
};
template <typename T>
struct maximize
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
};
template <typename T>
struct minimize
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
};
template <typename T>
struct integer_divide_ceiler
{
__host__ __device__ constexpr T operator()(T a, T b) const
{
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
return (a + b - Number<1>{}) / b;
}
};
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
}
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
return (x + y - Number<1>{}) / y;
}
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
{
return y * integer_divide_ceil(x, y);
}
template <typename T>
__host__ __device__ constexpr T max(T x)
{
return x;
}
template <typename T>
__host__ __device__ constexpr T max(T x, T y)
{
return x > y ? x : y;
}
template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
{
return X > y ? X : y;
}
template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
{
return x > Y ? x : Y;
}
template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
return max(x, max(ys...));
}
template <typename T>
__host__ __device__ constexpr T min(T x)
{
return x;
}
template <typename T>
__host__ __device__ constexpr T min(T x, T y)
{
return x < y ? x : y;
}
template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
{
return X < y ? X : y;
}
template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
{
return x < Y ? x : Y;
}
template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
return min(x, min(ys...));
}
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
if(x < 0)
{
return gcd(-x, y);
}
else if(y < 0)
{
return gcd(x, -y);
}
else if(x == y || x == 0)
{
return y;
}
else if(y == 0)
{
return x;
}
else if(x > y)
{
return gcd(x % y, y);
}
else
{
return gcd(x, y % x);
}
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
{
constexpr auto r = gcd(X, Y);
return Number<r>{};
}
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
return gcd(x, gcd(ys...));
}
// least common multiple
template <typename X, typename Y>
__host__ __device__ constexpr auto lcm(X x, Y y)
{
return (x * y) / gcd(x, y);
}
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
{
return lcm(x, lcm(ys...));
}
template <typename T>
struct equal
{
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
};
template <typename T>
struct less
{
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
};
} // namespace math
} // namespace ck
#endif

View File

@@ -0,0 +1,12 @@
#ifndef CK_MULTI_INDEX_HPP
#define CK_MULTI_INDEX_HPP
#include "common_header.hpp"
#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX
#include "array_multi_index.hpp"
#else
#include "statically_indexed_array_multi_index.hpp"
#endif
#endif

View File

@@ -0,0 +1,44 @@
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
#include "integral_constant.hpp"
namespace ck {
template <index_t N>
using Number = integral_constant<index_t, N>;
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator+(Number<X>, Number<Y>)
{
return Number<X + Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator-(Number<X>, Number<Y>)
{
static_assert(Y <= X, "wrong!");
return Number<X - Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator*(Number<X>, Number<Y>)
{
return Number<X * Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator/(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X / Y>{};
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto operator%(Number<X>, Number<Y>)
{
static_assert(Y > 0, "wrong!");
return Number<X % Y>{};
}
} // namespace ck
#endif

View File

@@ -0,0 +1,22 @@
#ifndef CK_PRINT_HPP
#define CK_PRINT_HPP
#include "array.hpp"
#include "statically_indexed_array.hpp"
#include "container_helper.hpp"
#include "sequence.hpp"
namespace ck {
template <typename T>
__host__ __device__ void print_array(const char* s, T a)
{
constexpr index_t nsize = a.Size();
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n");
}
} // namespace ck
#endif

View File

@@ -0,0 +1,880 @@
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
#include "integral_constant.hpp"
#include "type.hpp"
#include "functional.hpp"
#include "math.hpp"
namespace ck {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct Sequence;
template <typename Seq, index_t I>
struct sequence_split;
template <typename>
struct sequence_reverse;
template <typename>
struct sequence_map_inverse;
template <typename>
struct is_valid_sequence_map;
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>);
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq);
template <index_t... Is>
struct Sequence
{
using Type = Sequence;
using data_type = index_t;
static constexpr index_t mSize = sizeof...(Is);
__host__ __device__ static constexpr auto Size() { return Number<mSize>{}; }
__host__ __device__ static constexpr auto GetSize() { return Size(); }
__host__ __device__ static constexpr index_t At(index_t I)
{
// the last dummy element is to prevent compiler complain about empty array, when mSize = 0
const index_t mData[mSize + 1] = {Is..., 0};
return mData[I];
}
template <index_t I>
__host__ __device__ static constexpr auto At(Number<I>)
{
static_assert(I < mSize, "wrong! I too large");
return Number<At(I)>{};
}
template <index_t I>
__host__ __device__ static constexpr auto Get(Number<I>)
{
return At(Number<I>{});
}
template <typename I>
__host__ __device__ constexpr auto operator[](I i) const
{
return At(i);
}
template <index_t... IRs>
__host__ __device__ static constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/)
{
static_assert(sizeof...(Is) == sizeof...(IRs),
"wrong! reorder map should have the same size as Sequence to be rerodered");
static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
return Sequence<Type::At(Number<IRs>{})...>{};
}
// MapOld2New is Sequence<...>
template <typename MapOld2New>
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
{
static_assert(MapOld2New::Size() == Size(),
"wrong! reorder map should have the same size as Sequence to be rerodered");
static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
return ReorderGivenNew2Old(typename sequence_map_inverse<MapOld2New>::type{});
}
__host__ __device__ static constexpr auto Reverse()
{
return typename sequence_reverse<Type>::type{};
}
__host__ __device__ static constexpr auto Front()
{
static_assert(mSize > 0, "wrong!");
return At(Number<0>{});
}
__host__ __device__ static constexpr auto Back()
{
static_assert(mSize > 0, "wrong!");
return At(Number<mSize - 1>{});
}
__host__ __device__ static constexpr auto PopFront() { return sequence_pop_front(Type{}); }
__host__ __device__ static constexpr auto PopBack() { return sequence_pop_back(Type{}); }
template <index_t... Xs>
__host__ __device__ static constexpr auto PushFront(Sequence<Xs...>)
{
return Sequence<Xs..., Is...>{};
}
template <index_t... Xs>
__host__ __device__ static constexpr auto PushFront(Number<Xs>...)
{
return Sequence<Xs..., Is...>{};
}
template <index_t... Xs>
__host__ __device__ static constexpr auto PushBack(Sequence<Xs...>)
{
return Sequence<Is..., Xs...>{};
}
template <index_t... Xs>
__host__ __device__ static constexpr auto PushBack(Number<Xs>...)
{
return Sequence<Is..., Xs...>{};
}
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Number<Ns>...)
{
return Sequence<Type::At(Number<Ns>{})...>{};
}
template <index_t... Ns>
__host__ __device__ static constexpr auto Extract(Sequence<Ns...>)
{
return Sequence<Type::At(Number<Ns>{})...>{};
}
template <index_t I, index_t X>
__host__ __device__ static constexpr auto Modify(Number<I>, Number<X>)
{
static_assert(I < Size(), "wrong!");
using seq_split = sequence_split<Type, I>;
constexpr auto seq_left = typename seq_split::left_type{};
constexpr auto seq_right = typename seq_split::right_type{}.PopFront();
return seq_left.PushBack(Number<X>{}).PushBack(seq_right);
}
template <typename F>
__host__ __device__ static constexpr auto Transform(F f)
{
return Sequence<f(Is)...>{};
}
__host__ __device__ static void Print()
{
printf("{");
printf("size %d, ", index_t{Size()});
static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); });
printf("}");
}
};
// merge sequence
template <typename Seq, typename... Seqs>
struct sequence_merge
{
using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type;
};
template <index_t... Xs, index_t... Ys>
struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Xs..., Ys...>;
};
template <typename Seq>
struct sequence_merge<Seq>
{
using type = Seq;
};
// generate sequence
template <index_t NSize, typename F>
struct sequence_gen
{
template <index_t IBegin, index_t NRemain, typename G>
struct sequence_gen_impl
{
static constexpr index_t NRemainLeft = NRemain / 2;
static constexpr index_t NRemainRight = NRemain - NRemainLeft;
static constexpr index_t IMiddle = IBegin + NRemainLeft;
using type = typename sequence_merge<
typename sequence_gen_impl<IBegin, NRemainLeft, G>::type,
typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 1, G>
{
static constexpr index_t Is = G{}(Number<I>{});
using type = Sequence<Is>;
};
template <index_t I, typename G>
struct sequence_gen_impl<I, 0, G>
{
using type = Sequence<>;
};
using type = typename sequence_gen_impl<0, NSize, F>::type;
};
// arithmetic sequence
template <index_t IBegin, index_t IEnd, index_t Increment>
struct arithmetic_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t i) const
{
return i * Increment + IBegin;
}
};
using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type;
};
// uniform sequence
template <index_t NSize, index_t I>
struct uniform_sequence_gen
{
struct F
{
__host__ __device__ constexpr index_t operator()(index_t) const { return I; }
};
using type = typename sequence_gen<NSize, F>::type;
};
// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce, Init>
{
using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce, Init>::type;
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
using type = typename sequence_merge<Sequence<new_reduce>, old_scan>::type;
};
template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce, Init>
{
using type = Sequence<Reduce{}(I, Init)>;
};
template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<Sequence<>, Reduce, Init>
{
using type = Sequence<>;
};
// split sequence
template <typename Seq, index_t I>
struct sequence_split
{
static constexpr index_t NSize = Seq{}.Size();
using range0 = typename arithmetic_sequence_gen<0, I, 1>::type;
using range1 = typename arithmetic_sequence_gen<I, NSize, 1>::type;
using left_type = decltype(Seq::Extract(range0{}));
using right_type = decltype(Seq::Extract(range1{}));
};
// reverse sequence
template <typename Seq>
struct sequence_reverse
{
static constexpr index_t NSize = Seq{}.Size();
using seq_split = sequence_split<Seq, NSize / 2>;
using type = typename sequence_merge<
typename sequence_reverse<typename seq_split::right_type>::type,
typename sequence_reverse<typename seq_split::left_type>::type>::type;
};
template <index_t I>
struct sequence_reverse<Sequence<I>>
{
using type = Sequence<I>;
};
template <index_t I0, index_t I1>
struct sequence_reverse<Sequence<I0, I1>>
{
using type = Sequence<I1, I0>;
};
#if 1
template <typename Reduce, typename Seq, typename... Seqs>
struct sequence_reduce
{
using type = typename sequence_reduce<Reduce,
Seq,
typename sequence_reduce<Reduce, Seqs...>::type>::type;
};
template <typename Reduce, index_t... Xs, index_t... Ys>
struct sequence_reduce<Reduce, Sequence<Xs...>, Sequence<Ys...>>
{
using type = Sequence<Reduce{}(Xs, Ys)...>;
};
template <typename Reduce, typename Seq>
struct sequence_reduce<Reduce, Seq>
{
using type = Seq;
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
{
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::Front() : RightValues::Front();
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
using new_left_values =
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
using new_left_ids =
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
using new_right_values =
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
using new_right_ids =
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
Sequence<>,
Sequence<>,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<Sequence<>,
Sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge
{
using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
Sequence<>,
Sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
static constexpr index_t nsize = Values::Size();
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
};
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
{
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
using sorted_values =
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
};
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
{
using sorted_values = Sequence<Value>;
using sorted_ids = Sequence<Id>;
};
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
{
using sorted_values = Sequence<>;
using sorted_ids = Sequence<>;
};
template <typename Values, typename Compare>
struct sequence_sort
{
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
};
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
template <typename RemainValues,
typename RemainIds,
typename UniquifiedValues,
typename UniquifiedIds,
typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t current_value = RemainValues::Front();
static constexpr index_t current_id = RemainIds::Front();
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
using new_remain_values = decltype(RemainValues::PopFront());
using new_remain_ids = decltype(RemainIds::PopFront());
using new_uniquified_values =
typename conditional<is_unique_value,
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
UniquifiedValues>::type;
using new_uniquified_ids =
typename conditional<is_unique_value,
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
UniquifiedIds>::type;
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
new_remain_ids,
new_uniquified_values,
new_uniquified_ids,
Eq>;
// this is output
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
struct sorted_sequence_uniquify_impl<Sequence<>,
Sequence<>,
UniquifiedValues,
UniquifiedIds,
Eq>
{
using uniquified_values = UniquifiedValues;
using uniquified_ids = UniquifiedIds;
};
template <typename SortedValues, typename SortedIds, typename Eq>
struct sorted_sequence_uniquify
{
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
decltype(SortedIds::PopFront()),
Sequence<SortedValues::Front()>,
Sequence<SortedIds::Front()>,
Eq>;
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
using sort = sequence_sort<Values, Less>;
using sorted_values = typename sort::type;
using sorted_ids = typename sort::sorted2unsorted_map;
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
// this is output
using type = typename uniquify::uniquified_values;
using sorted2unsorted_map = typename uniquify::uniquified_ids;
};
template <typename SeqMap>
struct is_valid_sequence_map : is_same<typename arithmetic_sequence_gen<0, SeqMap::Size(), 1>::type,
typename sequence_sort<SeqMap, math::less<index_t>>::type>
{
};
template <typename SeqMap>
struct sequence_map_inverse
{
template <typename X2Y, typename WorkingY2X, index_t XBegin, index_t XRemain>
struct sequence_map_inverse_impl
{
static constexpr auto new_y2x =
WorkingY2X::Modify(X2Y::At(Number<XBegin>{}), Number<XBegin>{});
using type =
typename sequence_map_inverse_impl<X2Y, decltype(new_y2x), XBegin + 1, XRemain - 1>::
type;
};
template <typename X2Y, typename WorkingY2X, index_t XBegin>
struct sequence_map_inverse_impl<X2Y, WorkingY2X, XBegin, 0>
{
using type = WorkingY2X;
};
using type =
typename sequence_map_inverse_impl<SeqMap,
typename uniform_sequence_gen<SeqMap::Size(), 0>::type,
0,
SeqMap::Size()>::type;
};
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return Sequence<(Xs + Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return Sequence<(Xs - Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return Sequence<(Xs * Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return Sequence<(Xs / Ys)...>{};
}
template <index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size");
return Sequence<(Xs % Ys)...>{};
}
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
{
return Sequence<(Xs + Y)...>{};
}
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
{
return Sequence<(Xs - Y)...>{};
}
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator*(Sequence<Xs...>, Number<Y>)
{
return Sequence<(Xs * Y)...>{};
}
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator/(Sequence<Xs...>, Number<Y>)
{
return Sequence<(Xs / Y)...>{};
}
template <index_t... Xs, index_t Y>
__host__ __device__ constexpr auto operator%(Sequence<Xs...>, Number<Y>)
{
return Sequence<(Xs % Y)...>{};
}
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator+(Number<Y>, Sequence<Xs...>)
{
return Sequence<(Y + Xs)...>{};
}
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator-(Number<Y>, Sequence<Xs...>)
{
return Sequence<(Y - Xs)...>{};
}
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator*(Number<Y>, Sequence<Xs...>)
{
return Sequence<(Y * Xs)...>{};
}
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator/(Number<Y>, Sequence<Xs...>)
{
return Sequence<(Y / Xs)...>{};
}
template <index_t Y, index_t... Xs>
__host__ __device__ constexpr auto operator%(Number<Y>, Sequence<Xs...>)
{
return Sequence<(Y % Xs)...>{};
}
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
{
return Sequence<Is...>{};
}
template <typename Seq>
__host__ __device__ constexpr auto sequence_pop_back(Seq)
{
static_assert(Seq::Size() > 0, "wrong! cannot pop an empty Sequence!");
return sequence_pop_front(Seq::Reverse()).Reverse();
}
template <typename... Seqs>
__host__ __device__ constexpr auto merge_sequences(Seqs...)
{
return typename sequence_merge<Seqs...>::type{};
}
template <typename F, index_t... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
{
return Sequence<f(Xs)...>{};
}
template <typename F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize, "Dim not the same");
return Sequence<f(Xs, Ys)...>{};
}
template <typename F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
{
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
"Dim not the same");
return Sequence<f(Xs, Ys, Zs)...>{};
}
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return typename sequence_reverse_inclusive_scan<Seq, Reduce, Init>::type{};
}
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number<Init>{})
.PushBack(Number<Init>{});
}
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number<Init>)
{
return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}, Number<Init>{}).Reverse();
}
template <typename Seq, index_t... Is>
__host__ __device__ constexpr auto pick_sequence_elements_by_ids(Seq, Sequence<Is...> /* ids */)
{
return Sequence<Seq::At(Number<Is>{})...>{};
}
#if 1
namespace detail {
template <typename WorkSeq, typename RemainSeq, typename RemainMask>
struct pick_sequence_elements_by_mask_impl
{
using new_work_seq = typename conditional<RemainMask::Front(),
decltype(WorkSeq::PushBack(RemainSeq::Front())),
WorkSeq>::type;
using type =
typename pick_sequence_elements_by_mask_impl<new_work_seq,
decltype(RemainSeq::PopFront()),
decltype(RemainMask::PopFront())>::type;
};
template <typename WorkSeq>
struct pick_sequence_elements_by_mask_impl<WorkSeq, Sequence<>, Sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Mask>
__host__ __device__ constexpr auto pick_sequence_elements_by_mask(Seq, Mask)
{
static_assert(Seq::Size() == Mask::Size(), "wrong!");
return typename detail::pick_sequence_elements_by_mask_impl<Sequence<>, Seq, Mask>::type{};
}
namespace detail {
template <typename WorkSeq, typename RemainValues, typename RemainIds>
struct modify_sequence_elements_by_ids_impl
{
using new_work_seq = decltype(WorkSeq::Modify(RemainIds::Front(), RemainValues::Front()));
using type =
typename modify_sequence_elements_by_ids_impl<new_work_seq,
decltype(RemainValues::PopFront()),
decltype(RemainIds::PopFront())>::type;
};
template <typename WorkSeq>
struct modify_sequence_elements_by_ids_impl<WorkSeq, Sequence<>, Sequence<>>
{
using type = WorkSeq;
};
} // namespace detail
template <typename Seq, typename Values, typename Ids>
__host__ __device__ constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids)
{
static_assert(Values::Size() == Ids::Size() && Seq::Size() >= Values::Size(), "wrong!");
return typename detail::modify_sequence_elements_by_ids_impl<Seq, Values, Ids>::type{};
}
#endif
template <typename Seq, typename Reduce, index_t Init>
__host__ __device__ constexpr index_t
reduce_on_sequence(Seq, Reduce f, Number<Init> /*initial_value*/)
{
index_t result = Init;
for(index_t i = 0; i < Seq::Size(); ++i)
{
result = f(result, Seq::At(i));
}
return result;
}
// TODO: a generic any_of for any container
template <typename Seq, typename F>
__host__ __device__ constexpr bool sequence_any_of(Seq, F f)
{
bool flag = false;
for(index_t i = 0; i < Seq::Size(); ++i)
{
flag = flag || f(Seq::At(i));
}
return flag;
}
// TODO: a generic all_of for any container
template <typename Seq, typename F>
__host__ __device__ constexpr bool sequence_all_of(Seq, F f)
{
bool flag = true;
for(index_t i = 0; i < Seq::Size(); ++i)
{
flag = flag && f(Seq::At(i));
}
return flag;
}
} // namespace ck
#endif

View File

@@ -0,0 +1,36 @@
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#include "tuple.hpp"
namespace ck {
template <index_t... Is>
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
{
return Sequence<Is...>{};
}
// F returns index_t
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
{
return typename sequence_gen<N, F>::type{};
}
// F returns Number<>
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
{
return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
template <index_t... Is>
__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
{
return Sequence<Is...>{};
}
} // namespace ck
#endif

View File

@@ -0,0 +1,71 @@
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP
#include "statically_indexed_array.hpp"
namespace ck {
template <AddressSpaceEnum_t BufferAddressSpace,
typename T,
index_t N,
bool InvalidElementUseNumericalZeroValue>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
using type = T;
using base = StaticallyIndexedArray<T, N>;
T invalid_element_value_ = T{0};
__host__ __device__ constexpr StaticBuffer() : base{} {}
__host__ __device__ constexpr StaticBuffer(T invalid_element_value)
: base{}, invalid_element_value_{invalid_element_value}
{
}
__host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
{
return BufferAddressSpace;
}
template <index_t I>
__host__ __device__ constexpr auto Get(Number<I> i, bool is_valid_element) const
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return is_valid_element ? At(i) : T{0};
}
else
{
return is_valid_element ? At(i) : invalid_element_value_;
}
}
template <index_t I>
__host__ __device__ void Set(Number<I> i, bool is_valid_element, const T& x)
{
if(is_valid_element)
{
At(i) = x;
}
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
return StaticBuffer<BufferAddressSpace, T, N, true>{};
}
template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>, T invalid_element_value)
{
return StaticBuffer<BufferAddressSpace, T, N, false>{invalid_element_value};
}
} // namespace ck
#endif

View File

@@ -0,0 +1,40 @@
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP
#include "functional2.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
namespace ck {
namespace detail {
template <typename T, index_t NSize>
__host__ __device__ constexpr auto generate_same_type_tuple()
{
return generate_tuple([](auto) -> T { return T{}; }, Number<NSize>{});
}
template <typename T, index_t NSize>
using same_type_tuple = decltype(generate_same_type_tuple<T, NSize>());
} // namespace detail
template <typename T, index_t NSize>
using StaticallyIndexedArray = detail::same_type_tuple<T, NSize>;
template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs)
{
return StaticallyIndexedArray<X, sizeof...(Xs) + 1>(x, static_cast<X>(xs)...);
}
// make empty StaticallyIndexedArray
template <typename X>
__host__ __device__ constexpr auto make_statically_indexed_array()
{
return StaticallyIndexedArray<X, 0>();
}
} // namespace ck
#endif

View File

@@ -0,0 +1,108 @@
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp"
namespace ck {
template <index_t N>
using MultiIndex = StaticallyIndexedArray<index_t, N>;
template <typename... Xs>
__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
{
return make_statically_indexed_array<index_t>(index_t{xs}...);
}
template <index_t NSize>
__host__ __device__ constexpr auto make_zero_multi_index()
{
return unpack([](auto... xs) { return make_multi_index(xs...); },
typename uniform_sequence_gen<NSize, 0>::type{});
}
template <typename T>
__host__ __device__ constexpr auto to_multi_index(const T& x)
{
return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
}
// Here should use MultiIndex<NSize>, instead of Tuple<Ys...>, although the former
// is the alias of the latter. This is because compiler cannot infer the NSize if
// using MultiIndex<NSize>
// TODO: how to fix this?
template <typename... Ys, typename X>
__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
return y;
}
template <typename... Ys, typename X>
__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
return y;
}
template <typename... Xs, typename Y>
__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; });
return r;
}
template <typename... Xs, typename Y>
__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; });
return r;
}
template <typename... Xs, typename Y>
__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; });
return r;
}
// MultiIndex = index_t * MultiIndex
template <typename... Xs>
__host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
{
constexpr index_t NSize = sizeof...(Xs);
Tuple<Xs...> r;
static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; });
return r;
}
template <typename... Xs>
__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
{
printf("{");
printf("MultiIndex, ");
printf("size %d,", index_t{sizeof...(Xs)});
static_for<0, sizeof...(Xs), 1>{}(
[&](auto i) { printf("%d ", static_cast<index_t>(x.At(i))); });
printf("}");
}
} // namespace ck
#endif

View File

@@ -0,0 +1,21 @@
#ifndef CK_SYNCHRONIZATION_AMD_HPP
#define CK_SYNCHRONIZATION_AMD_HPP
#include "config.hpp"
namespace ck {
__device__ void block_sync_lds()
{
#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
#else
__syncthreads();
#endif
}
} // namespace ck
#endif

View File

@@ -0,0 +1,166 @@
#ifndef CK_TUPLE_HPP
#define CK_TUPLE_HPP
#include "integral_constant.hpp"
#include "sequence.hpp"
#include "type.hpp"
#include "enable_if.hpp"
namespace ck {
namespace detail {
template <index_t>
struct TupleElementKey
{
__host__ __device__ constexpr TupleElementKey() = default;
};
template <typename Key, typename Data>
struct TupleElement
{
__host__ __device__ constexpr TupleElement() = default;
template <typename T,
typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
bool>::type = false>
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
{
}
Data mData;
};
template <typename Key, typename Data>
__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement<Key, Data>& x)
{
return static_cast<const Data&>(x.mData);
}
template <typename Key, typename Data>
__host__ __device__ constexpr Data& get_tuple_element(TupleElement<Key, Data>& x)
{
return x.mData;
}
// TODO: not sure the use of reference is correct
template <typename Key, typename Data>
__host__ __device__ constexpr Data&& get_tuple_element(TupleElement<Key, Data>&& x)
{
return static_cast<Data&&>(x.mData);
}
template <typename Indices, typename... Xs>
struct TupleImpl;
template <index_t... Is, typename... Xs>
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
{
__host__ __device__ constexpr TupleImpl() = default;
template <typename Y,
typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{
}
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size");
}
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey<I>) const
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
template <index_t I>
__host__ __device__ constexpr auto& GetElementByKey(TupleElementKey<I>)
{
return get_tuple_element<TupleElementKey<I>>(*this);
}
};
} // namespace detail
template <typename... Xs>
struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>
{
using base =
detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(Xs), 1>::type, Xs...>;
__host__ __device__ constexpr Tuple() = default;
template <typename Y,
typename enable_if<sizeof...(Xs) == 1 &&
!is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
{
}
template <typename... Ys,
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
{
}
__host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); }
template <index_t I>
__host__ __device__ constexpr const auto& At(Number<I>) const
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
template <index_t I>
__host__ __device__ constexpr auto& At(Number<I>)
{
static_assert(I < base::Size(), "wrong! out of range");
return base::GetElementByKey(detail::TupleElementKey<I>{});
}
template <index_t I>
__host__ __device__ constexpr const auto& operator[](Number<I> i) const
{
return At(i);
}
template <index_t I>
__host__ __device__ constexpr auto& operator()(Number<I> i)
{
return At(i);
}
template <typename T>
__host__ __device__ constexpr auto operator=(const T& a)
{
static_assert(T::Size() == Size(), "wrong! size not the same");
static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; });
return *this;
}
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
};
template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{
return Tuple<remove_cv_t<remove_reference_t<Xs>>...>(std::forward<Xs>(xs)...);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,80 @@
#ifndef CK_TUPLE_HELPER_HPP
#define CK_TUPLE_HELPER_HPP
#include "functional4.hpp"
#include "tuple.hpp"
namespace ck {
template <typename... Ts>
struct is_known_at_compile_time<Tuple<Ts...>>
{
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return container_reduce(
Tuple<Ts...>{},
[](auto x, bool r) {
return is_known_at_compile_time<
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
r;
},
true);
}
static constexpr bool value = IsKnownAtCompileTime();
};
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
typename arithmetic_sequence_gen<0, N, 1>::type{});
}
namespace detail {
template <typename F, typename X, index_t... Is>
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
}
template <typename F, typename X, typename Y, typename Z, index_t... Is>
__host__ __device__ constexpr auto
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
{
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
}
} // namespace detail
template <typename F, typename X>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
{
return detail::transform_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
{
return detail::transform_tuples_impl(
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <typename F, typename X, typename Y, typename Z>
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
{
return detail::transform_tuples_impl(
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
} // namespace ck
#endif

View File

@@ -0,0 +1,56 @@
#ifndef CK_TYPE_HPP
#define CK_TYPE_HPP
#include "integral_constant.hpp"
#include "enable_if.hpp"
namespace ck {
template <typename X, typename Y>
struct is_same : public integral_constant<bool, false>
{
};
template <typename X>
struct is_same<X, X> : public integral_constant<bool, true>
{
};
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;
template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;
template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
template <typename T>
struct is_known_at_compile_time;
template <>
struct is_known_at_compile_time<index_t>
{
static constexpr bool value = false;
};
template <typename T, T X>
struct is_known_at_compile_time<integral_constant<T, X>>
{
static constexpr bool value = true;
};
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x)
{
union AsType
{
X x;
Y y;
};
return AsType{x}.y;
}
} // namespace ck
#endif

View File

@@ -0,0 +1,14 @@
#ifndef CK_UTILITY_HPP
#define CK_UTILITY_HPP
#include "config.hpp"
namespace ck {
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
__device__ index_t get_block_1d_id() { return blockIdx.x; }
} // namespace ck
#endif

View File

@@ -0,0 +1,370 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
constexpr index_t KPerThread = CK_PARAM_KPerThread;
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
using ABlockTransferThreadSliceLengths_K_M0_M1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K_M0_M1>;
using ABlockTransferThreadClusterLengths_K_M0_M1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K_M0_M1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_M1 =
CK_PARAM_ABlockTransferDstScalarPerVector_M1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K_N0_N1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K_N0_N1>;
using BBlockTransferThreadClusterLengths_K_N0_N1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K_N0_N1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_N1 =
CK_PARAM_BBlockTransferDstScalarPerVector_N1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
int n,
int c,
int hi,
int wi,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k_m0_m1_grid_desc,
void* p_b_k_n0_n1_grid_desc,
void* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW));
const auto a_k_m_grid_desc = descs[I0];
const auto b_k_n_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AKMGridDesc = decltype(a_k_m_grid_desc);
using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
auto c_m0_m10_m11_n0_n10_n11_grid_desc =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<decltype(a_k_m0_m1_grid_desc)*>(p_a_k_m0_m1_grid_desc) = a_k_m0_m1_grid_desc;
*static_cast<decltype(b_k_n0_n1_grid_desc)*>(p_b_k_n0_n1_grid_desc) = b_k_n0_n1_grid_desc;
*static_cast<decltype(c_m0_m10_m11_n0_n10_n11_grid_desc)*>(
p_c_m0_m10_m11_n0_n10_n11_grid_desc) = c_m0_m10_m11_n0_n10_n11_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
};
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1));
constexpr auto a_k_m_grid_desc = descs[I0];
constexpr auto b_k_n_grid_desc = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AKMGridDesc = decltype(a_k_m_grid_desc);
using BKNGridDesc = decltype(b_k_n_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseGemmDlops_km_kn_mn_v1r2<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
AKMGridDesc,
BKNGridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
constexpr auto a_k_m0_m1_grid_desc_tmp =
GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
constexpr auto b_k_n0_n1_grid_desc_tmp =
GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
constexpr auto c_m0_m10_m11_n0_n10_n11_grid_desc_tmp =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);
using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc_tmp);
using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc_tmp);
using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k_m0_m1_grid_desc =
*reinterpret_cast<const AKM0M1GridDesc*>((const void*)p_a_k_m0_m1_grid_desc);
const auto b_k_n0_n1_grid_desc =
*reinterpret_cast<const BKN0N1GridDesc*>((const void*)p_b_k_n0_n1_grid_desc);
const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
*reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
(const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k_m0_m1_grid_desc,
b_k_n0_n1_grid_desc,
c_m0_m10_m11_n0_n10_n11_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};

View File

@@ -0,0 +1,358 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;
using ABlockTransferThreadSliceLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
int n,
int c,
int hi,
int wi,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(n, c, hi, wi));
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW),
Number<K1>{});
const auto a_k0_m_k1_grid_desc = descs[I0];
const auto b_k0_n_k1_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using AGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
a_k0_m_k1_grid_desc;
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
Number<K1>{});
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using BGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
};

View File

@@ -0,0 +1,357 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t MPerBlock = CK_PARAM_MPerBlock;
constexpr index_t NPerBlock = CK_PARAM_NPerBlock;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t MPerWave = CK_PARAM_MPerWave;
constexpr index_t NPerWave = CK_PARAM_NPerWave;
constexpr index_t MRepeat = CK_PARAM_MRepeat;
constexpr index_t NRepeat = CK_PARAM_NRepeat;
constexpr index_t K1 = CK_PARAM_K1;
using ABlockTransferThreadSliceLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_K0_M_K1>;
using ABlockTransferThreadClusterLengths_K0_M_K1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_K0_M_K1>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_K1 =
CK_PARAM_ABlockTransferDstScalarPerVector_K1;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_K0_N_K1>;
using BBlockTransferThreadClusterLengths_K0_N_K1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_K0_N_K1>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_K1 =
CK_PARAM_BBlockTransferDstScalarPerVector_K1;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
extern "C" __global__ void convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
int n,
int hi,
int wi,
int c,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_k0_m_k1_grid_desc,
void* p_b_k0_n_k1_grid_desc,
void* p_c_m0_m1_m2_n_grid_desc,
void* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(n, hi, wi, c));
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(make_tuple(k, y, x, c));
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(make_tuple(n, ho, wo, k));
const auto descs = transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(
in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW),
Number<K1>{});
const auto a_k0_m_k1_grid_desc = descs[I0];
const auto b_k0_n_k1_grid_desc = descs[I1];
const auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
auto c_blockid_to_m0_n0_block_cluster_adaptor =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<remove_cv_t<decltype(a_k0_m_k1_grid_desc)>*>(p_a_k0_m_k1_grid_desc) =
a_k0_m_k1_grid_desc;
*static_cast<remove_cv_t<decltype(b_k0_n_k1_grid_desc)>*>(p_b_k0_n_k1_grid_desc) =
b_k0_n_k1_grid_desc;
*static_cast<decltype(c_m0_m1_m2_n_grid_desc)*>(p_c_m0_m1_m2_n_grid_desc) =
c_m0_m1_m2_n_grid_desc;
*static_cast<decltype(c_blockid_to_m0_n0_block_cluster_adaptor)*>(
p_c_blockid_to_m0_n0_block_cluster_adaptor) = c_blockid_to_m0_n0_block_cluster_adaptor;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_hi_wi_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
constexpr auto wei_k_y_x_c_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 3, 3, 256));
constexpr auto out_n_ho_wo_k_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 28, 28, 256));
constexpr auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
Number<K1>{});
constexpr auto a_k0_m_k1_grid_desc_tmp = descs[I0];
constexpr auto b_k0_n_k1_grid_desc_tmp = descs[I1];
constexpr auto c_m_n_grid_desc = descs[I2];
using AK0MK1GridDesc = decltype(a_k0_m_k1_grid_desc_tmp);
using BK0NK1GridDesc = decltype(b_k0_n_k1_grid_desc_tmp);
using CMNGridDesc = decltype(c_m_n_grid_desc);
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})));
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridStepHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{})));
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;
using BGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0>;
using GridwiseGemm =
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AK0MK1GridDesc,
BK0NK1GridDesc,
CMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerWave,
NPerWave,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
false>;
constexpr auto c_m0_m1_m2_n_grid_desc_tmp =
GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);
constexpr auto c_blockid_to_m0_n0_block_cluster_adaptor_tmp =
GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc_tmp);
using CBlockIdToM0N0BlockClusterAdaptor =
decltype(c_blockid_to_m0_n0_block_cluster_adaptor_tmp);
const auto a_k0_m_k1_grid_desc =
*reinterpret_cast<const AK0MK1GridDesc*>((const void*)p_a_k0_m_k1_grid_desc);
const auto b_k0_n_k1_grid_desc =
*reinterpret_cast<const BK0NK1GridDesc*>((const void*)p_b_k0_n_k1_grid_desc);
const auto c_m0_m1_m2_n_grid_desc =
*reinterpret_cast<const CM0M1M2NGridDesc*>((const void*)p_c_m0_m1_m2_n_grid_desc);
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseGemm::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_m1_m2_n_grid_desc,
c_blockid_to_m0_n0_block_cluster_adaptor);
};

View File

@@ -0,0 +1,405 @@
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
using namespace ck;
constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr auto GN0 = Number<CK_PARAM_GN0>{};
constexpr auto GK1 = Number<CK_PARAM_GK1>{};
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>;
using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>;
using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>;
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);
extern "C" __global__ void
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_,
int C_,
int Hi_,
int Wi_,
int K_,
int Y_,
int X_,
int ConvStrideH_,
int ConvStrideW_,
int ConvDilationH_,
int ConvDilationW_,
int InLeftPadH_,
int InLeftPadW_,
int InRightPadH_,
int InRightPadW_,
void* p_desc_tuple)
{
index_t N = static_cast<index_t>(N_);
index_t C = static_cast<index_t>(C_);
index_t Hi = static_cast<index_t>(Hi_);
index_t Wi = static_cast<index_t>(Wi_);
index_t K = static_cast<index_t>(K_);
index_t Y = static_cast<index_t>(Y_);
index_t X = static_cast<index_t>(X_);
index_t ConvStrideH = static_cast<index_t>(ConvStrideH_);
index_t ConvStrideW = static_cast<index_t>(ConvStrideW_);
index_t ConvDilationH = static_cast<index_t>(ConvDilationH_);
index_t ConvDilationW = static_cast<index_t>(ConvDilationW_);
index_t InLeftPadH = static_cast<index_t>(InLeftPadH_);
index_t InLeftPadW = static_cast<index_t>(InLeftPadW_);
index_t InRightPadH = static_cast<index_t>(InRightPadH_);
index_t InRightPadW = static_cast<index_t>(InRightPadW_);
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t Ho =
(Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
const index_t Wo =
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C, Hi, Wi));
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C, Y, X));
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo));
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(ConvStrideH, ConvStrideW),
make_tuple(ConvDilationH, ConvDilationW),
make_tuple(InLeftPadH, InLeftPadW),
make_tuple(InRightPadH, InRightPadW),
GN0,
GK1);
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridStepHacks = decltype(make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction =
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
{
auto desc_tuple =
make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
a_grid_desc_gk0_gm0_gm1_gk1),
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
b_grid_desc_gk0_gn0_gn1_gk1),
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1),
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1));
*static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
}
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_desc_tuple)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto in_n_c_hi_wi_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_naive_tensor_descriptor_packed(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
GN0,
GK1);
constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridStepHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridStepHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridStepHacks = decltype(make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowStepHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowStepHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction =
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
a_grid_desc_gk0_gm0_gm1_gk1));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
b_grid_desc_gk0_gn0_gn1_gk1));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1));
using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
BGridDesc_GK0_GN0_GN10_GN11_GK1{},
CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
CGridBlockCluster_BlockId_To_GM10_GN10{}));
const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// TODO: how to cast?
(const void*)p_desc_tuple
#pragma clang diagnostic pop
);
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0];
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1];
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = desc_tuple[I2];
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};

125
external/rocm/include/bfloat16_dev.hpp vendored Normal file
View File

@@ -0,0 +1,125 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2019 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef BFLOAT16_DEVICE_HPP
#define BFLOAT16_DEVICE_HPP
#ifdef __cplusplus
extern "C" {
#endif
#ifdef __HIP_PLATFORM_HCC__
#define EXECUTION_SPECIFIER __device__
#else
#define EXECUTION_SPECIFIER
#endif // MIOPEN_BACKEND_HIP
typedef union
{
uint u32;
ushort2 ushortx2;
// Composable kernels are written in HIP language. The language doesnt support
// ushort2.hi or ushort2.low.
#ifdef __HIP_PLATFORM_HCC__
ushort ushortvec[2];
#endif // MIOPEN_BACKEND_HIP
float f32;
} cvt_bf16_fp32_t;
EXECUTION_SPECIFIER float bfloat16_to_float(ushort src_val)
{
cvt_bf16_fp32_t target_val;
#ifdef __HIP_PLATFORM_HCC__
target_val.ushortx2 = make_ushort2(0, src_val);
#else
target_val.ushortx2 = (ushort2)(0, src_val);
#endif
return target_val.f32;
}
EXECUTION_SPECIFIER ushort float_to_bfloat16(float src_val)
{
cvt_bf16_fp32_t target_val;
target_val.f32 = src_val;
// BF16 round and NaN preservation code matches
// https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/include/rocblas_bfloat16.h
if((~target_val.u32 & 0x7f800000) == 0) // Inf or NaN
{
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bloat16's mantissa bits are all 0.
if((target_val.u32 & 0xffff) != 0)
{
target_val.u32 |= 0x10000; // Preserve signaling NaN
}
}
else
{
#ifdef MIOPEN_USE_RNE_BFLOAT16
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
#ifdef __HIP_PLATFORM_HCC__
target_val.u32 += (0x7fff + (target_val.ushortvec[1] & 1));
#else
target_val.u32 +=
(0x7fff + (target_val.ushortx2.hi & 1)); // Round to nearest, round to even
#endif // MIOPEN_BACKEND_HIP
#endif // MIOPEN_USE_RNE_BFLOAT16
}
#ifdef __HIP_PLATFORM_HCC__
return target_val.ushortvec[1];
#else
return target_val.ushortx2.hi;
#endif // MIOPEN_BACKEND_HIP
}
#ifdef __cplusplus
}
#endif
#endif // BFLOAT16_DEVICE_HPP

2
host/CMakeLists.txt Normal file
View File

@@ -0,0 +1,2 @@
add_subdirectory(host_tensor)
add_subdirectory(driver_offline)

View File

@@ -0,0 +1,21 @@
include_directories(BEFORE
include
${PROJECT_SOURCE_DIR}/host/host_tensor/include
${PROJECT_SOURCE_DIR}/host/solver/include
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include
)
set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp)
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)

View File

@@ -0,0 +1,330 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
in_n_hi_wi_c_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
I0,
I0,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(out_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmABlockTransferSrcScalarPerVector_GemmM,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<1, 3, 7, 0, 2, 4, 5, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
out_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}

View File

@@ -0,0 +1,306 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
const Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs =
transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc,
wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
I0,
I0,
Number<GemmK1>{});
const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto in_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1
constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(in_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<2, 0, 1>,
Sequence<0, 2, 1>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
#else
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
#endif
7,
GemmCThreadTransferDstScalarPerVector,
decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(in_m0_m1_m2_n_grid_step_hacks),
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
true // CAccessOrderMRepeatNRepeat
>(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
out_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc,
out_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
in_m0_m1_m2_n_grid_step_hacks,
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data());
}

View File

@@ -0,0 +1,201 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_dlops_v1r2.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn0_gemmn1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_dlops_v1r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K_M0_M1,
GemmABlockTransferThreadClusterLengths_K_M0_M1,
Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder
0, // ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_K,
GemmABlockTransferDstScalarPerVector_M1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder
Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_N1,
GemmBBlockTransferDstScalarPerVector_N1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(wei_gemmk_gemmm0_gemmn1_grid_step_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm0_gemmn1_grid_step_hacks,
in_gemmk_gemmn0_gemmn1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_step_hacks,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,280 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs =
#if 1
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
#else
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
#endif
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
for(index_t i = 0; i < 5; ++i)
{
#if 0
float ave_time = launch_kernel_gemm_xdlops_v1
#else
float ave_time = launch_kernel_gemm_xdlops_v2
#endif
<BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(descs[I0]),
decltype(descs[I1]),
decltype(descs[I2]),
decltype(descs[I3]),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy, which will be fused
// with MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,273 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_dlops_v1r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [128, 128, 8, 1] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 1;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 2] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 2;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 8, 4] for i8
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmK1 = 4;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
using GemmM11N11ThreadClusterM110Xs = Sequence<8, 2>;
using GemmM11N11ThreadClusterN110Xs = Sequence<8, 2>;
using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;
using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10
Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10
Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11
Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0
Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10
Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11
constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_dlops_v1r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM110Xs,
GemmM11N11ThreadClusterN110Xs,
GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_N11,
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks),
decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_step_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_step_hacks,
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,197 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 0, 1, 2, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,229 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r2.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1>,
2,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,302 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,354 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths);
#if 0
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 0
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
wei_k_y_x_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN
Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0
Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN
Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1
constexpr auto out_m0_m1_m2_n_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat
Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat
Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves
Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves
Sequence<0, 0, 0, 0, 0>{}, // 4+: M0
Sequence<0, 0, 0, 0, 0>{}, // 5+: M1
Sequence<0, 0, 0, 0, 0>{}, // 6+: M2
Sequence<0, 0, 0, 0, 0>{}), // 7+: N1
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat
Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat
Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves
Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves
Sequence<0, 0, 0, 0, 0>{}, // 4-: M0
Sequence<0, 0, 0, 0, 0>{}, // 5-: M1
Sequence<0, 0, 0, 0, 0>{}, // 6-: M2
Sequence<0, 0, 0, 0, 0>{})); // 7-: N1
constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmK1,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
7,
GemmCThreadTransferDstScalarPerVector,
decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
decltype(out_m0_m1_m2_n_grid_step_hacks),
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
in_gemmk0_gemmm_gemmk1_grid_desc,
wei_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
in_gemmk0_gemmm_gemmk1_grid_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
out_m0_m1_m2_n_grid_step_hacks,
in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,190 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp"
#include "driver_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp"
template <typename TInWei,
ck::index_t InWeiVectorSize,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t /* nrepeat */)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = out_n_k_ho_wo_lengths[I0];
const auto K = out_n_k_ho_wo_lengths[I1];
const auto C = wei_k_c_y_x_lengths[I1];
const auto Hi = in_n_c_hi_wi_lengths[I2];
const auto Wi = in_n_c_hi_wi_lengths[I3];
const auto Ho = out_n_k_ho_wo_lengths[I2];
const auto Wo = out_n_k_ho_wo_lengths[I3];
const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3];
const auto C0 = C / Number<InWeiVectorSize>{};
const auto C1 = Number<InWeiVectorSize>{};
const auto K0 = K / Number<InWeiVectorSize>{};
const auto K1 = Number<InWeiVectorSize>{};
Tensor<TInWei> in_n_c0_hi_wi_c1(
HostTensorDescriptor(std::initializer_list<index_t>{N, C0, Hi, Wi, C1}));
Tensor<TInWei> wei_k_c0_y_x_c1(
HostTensorDescriptor(std::initializer_list<index_t>{K, C0, Y, X, C1}));
Tensor<TOut> out_n_k0_ho_wo_k1(
HostTensorDescriptor(std::initializer_list<index_t>{N, K0, Ho, Wo, K1}));
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
in_n_c_hi_wi(n, c, hi, wi);
};
auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
wei_k_c_y_x(k, c, y, x);
};
make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) *
in_n_c0_hi_wi_c1.mDesc.GetElementSpace());
DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace());
DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) *
out_n_k0_ho_wo_k1.mDesc.GetElementSpace());
in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
const auto in_n_c0_hi_wi_desc = make_naive_tensor_descriptor_packed(make_tuple(N, C0, Hi, Wi));
const auto wei_k_c0_y_x_desc = make_naive_tensor_descriptor_packed(make_tuple(K, C0, Y, X));
const auto out_n_k0_ho_wo_k1_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1));
#if 1
// cdata = 64, BlockSize = 64, 16x8x32x4
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = KPerBlock;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = 16;
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#else
constexpr index_t BlockSize = 64;
constexpr index_t KPerBlock = 16;
constexpr index_t HoPerBlock = 8;
constexpr index_t WoPerBlock = 32;
constexpr index_t EPerBlock = 1;
constexpr index_t KPerThread = 16;
constexpr index_t HoPerThread = 2;
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, 16>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
#endif
constexpr auto conv_driver =
#if 0
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
#else
DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
#endif
<BlockSize,
typename vector_type<TInWei, InWeiVectorSize>::type,
TAcc,
TOut,
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
BThreadTransferSrcScalarPerVector_W,
CThreadTransferDstScalarPerVector_W>{};
conv_driver.Run(wei_k_c0_y_x_desc,
in_n_c0_hi_wi_desc,
out_n_k0_ho_wo_k1_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()),
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer()));
out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
out_n_k_ho_wo(n, k, ho, wo) =
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
};
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
}

View File

@@ -0,0 +1,241 @@
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_contraction_dlops_v1r2.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_desc_n_c_hi_wi = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
const auto wei_desc_k_c_y_x = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
const auto out_desc_n_k_ho_wo = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
#if 1
// [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GN0 = 4;
constexpr index_t GK1 = 1;
constexpr index_t GM1PerBlockGM11 = 128;
constexpr index_t GN1PerBlockGN11 = 32;
constexpr index_t GK0PerBlock = 8;
constexpr index_t BM1PerThreadBM11 = 4;
constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1;
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
// cdata = 64, BlockSize = 256
constexpr index_t BlockSize = 256;
constexpr index_t GN0 = 4;
constexpr index_t GK1 = 2;
constexpr index_t GM1PerBlockGM11 = 128;
constexpr index_t GN1PerBlockGN11 = 32;
constexpr index_t GK0PerBlock = 8;
constexpr index_t BM1PerThreadBM11 = 4;
constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1;
using BM10BN10ThreadClusterBM10Xs = Sequence<8, 2>;
using BM10BN10ThreadClusterBN10Xs = Sequence<8, 2>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
#endif
const auto descs =
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
in_desc_n_c_hi_wi,
out_desc_n_k_ho_wo,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GN0>{},
Number<GK1>{});
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_grid_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
constexpr auto in_grid_step_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1
constexpr auto out_grid_step_hacks = make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1
constexpr auto wei_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};
constexpr auto in_grid_move_slice_window_step_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_contraction_dlops_v1r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperationEnum_t::Set,
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
CThreadTransferDstScalarPerVector_BN1,
decltype(wei_grid_step_hacks),
decltype(in_grid_step_hacks),
decltype(out_grid_step_hacks),
decltype(wei_grid_move_slice_window_step_hacks),
decltype(in_grid_move_slice_window_step_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_grid_desc_gk0_gm0_gm1_gk1,
in_grid_desc_gk0_gn0_gn1_gk1,
out_grid_desc_gm0_gm1_gn0_gn1,
wei_grid_step_hacks,
in_grid_step_hacks,
out_grid_step_hacks,
wei_grid_move_slice_window_step_hacks,
in_grid_move_slice_window_step_hacks,
nrepeat);
float perf = static_cast<float>(calculate_convolution_flops(
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,286 @@
#ifndef DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#define DRIVER_CONTRACTION_DLOPS_V1R2_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_contraction_dlops_v1r2.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
ck::index_t GM1PerBlockGM11,
ck::index_t GN1PerBlockGN11,
ck::index_t GK0PerBlock,
ck::index_t BM1PerThreadBM11,
ck::index_t BN1PerThreadBN11,
ck::index_t BK0PerThread,
typename BM10BN10ThreadClusterBM10Xs,
typename BM10BN10ThreadClusterBN10Xs,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferSrcVectorTensorContiguousDimOrder,
typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
typename BBlockTransferSrcVectorTensorContiguousDimOrder,
typename CThreadTransferSrcDstAccessOrder,
ck::index_t CThreadTransferSrcDstVectorDim,
ck::index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks>
__host__ float
driver_contraction_dlops_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
ck::index_t nrepeat)
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction =
GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
{
throw std::runtime_error("wrong! "
"GridwiseContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting");
}
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1);
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
// c_grid_block_cluster_blockid_to_gm10_gn10
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1);
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
const bool has_double_tail_k_block_loop =
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
{
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else
{
const auto kernel = kernel_contraction_dlops_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
return ave_time;
}
#endif

View File

@@ -0,0 +1,349 @@
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::index_t KPerBlock,
ck::index_t HoPerBlock,
ck::index_t WoPerBlock,
ck::index_t EPerBlock,
ck::index_t KPerThread,
ck::index_t HoPerThread,
ck::index_t WoPerThread,
ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
ck::index_t ABlockTransferSrcScalarPerVector_E,
ck::index_t ABlockTransferDstScalarPerVector_K,
ck::index_t BThreadTransferSrcScalarPerVector_W,
ck::index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
{
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_e_k_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_k_n_ho_wo_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pass_through_transform(Ho),
make_pass_through_transform(Wo)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X;
if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
(E % EPerBlock) == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e_k_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
constexpr auto b_e_n_ho_wo_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
#if 1
// GEMM
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_ho_wo_global_desc),
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>,
Sequence<1, 0>,
0,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>,
3,
BThreadTransferSrcScalarPerVector_W,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>,
0,
CThreadTransferDstScalarPerVector_W,
decltype(a_e_k_global_step_hacks),
decltype(b_e_n_ho_wo_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
decltype(a_e_k_global_move_slice_window_step_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
<< std::endl;
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel = run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_ho_wo_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_ho_wo_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf =
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
#endif
}
};
#endif

View File

@@ -0,0 +1,364 @@
#ifndef DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#define DRIVER_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_DLOPS_NCHW_KCYX_NKHW_OUTPAD_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_dlops_v2.hpp"
#include "gridwise_operation_wrapper.hpp"
template <ck::index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
ck::index_t KPerBlock,
ck::index_t HoPerBlock,
ck::index_t WoPerBlock,
ck::index_t EPerBlock,
ck::index_t KPerThread,
ck::index_t HoPerThread,
ck::index_t WoPerThread,
ck::index_t EPerThread,
typename ABlockTransferThreadSliceLengths_E_K,
typename ABlockTransferThreadClusterLengths_E_K,
ck::index_t ABlockTransferSrcScalarPerVector_E,
ck::index_t ABlockTransferDstScalarPerVector_K,
ck::index_t BThreadTransferSrcScalarPerVector_W,
ck::index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outpad
{
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ void Run(const ck::TensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
const ck::TensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
const ck::TensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const FloatAB* __restrict__ p_wei_global,
const FloatAB* __restrict__ p_in_global,
FloatC* __restrict__ p_out_global) const
{
using namespace ck;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
const auto OutRightPadH = Hop - Ho;
const auto OutRightPadW = Wop - Wo;
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
<< std::endl;
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
<< std::endl;
// weight tensor
const auto wei_e_k_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(
make_pass_through_transform(N),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
const auto in_e_n_ho_wo_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N),
make_pass_through_transform(Hop),
make_pass_through_transform(Wop)),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
// output tensor
const auto out_k_n_hop_wop_global_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N, K0, Ho, Wo, K1)),
make_tuple(make_merge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_pad_transform(Ho, 0, OutRightPadH),
make_pad_transform(Wo, 0, OutRightPadW)),
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto E = C * Y * X;
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
(E % EPerBlock) == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// hack to control index calculation when iterating over a_k_m_global tensor
constexpr auto a_e_k_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
constexpr auto a_e_k_global_move_slice_window_step_hack = Sequence<0, 0, 0>{};
constexpr auto b_e_n_ho_wo_global_step_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
// hack for NKHW format
constexpr auto c_k_n_ho_wo_global_tensor_step_hacks =
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
// GEMM
using gridwise_gemm = GridwiseGemmDlops_km_kn_mn_v3<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperationEnum_t::Set,
decltype(wei_e_k_global_desc),
decltype(in_e_n_ho_wo_global_desc),
decltype(out_k_n_hop_wop_global_desc),
KPerBlock,
HoPerBlock,
WoPerBlock,
EPerBlock,
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
Sequence<1, 0>,
Sequence<1, 0>,
0,
ABlockTransferSrcScalarPerVector_E,
ABlockTransferDstScalarPerVector_K,
false, // don't move back src coordinate after threadwise copy
Sequence<0, 2, 3, 1>,
3,
BThreadTransferSrcScalarPerVector_W,
false, // don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence<0, 2, 3, 1>,
0,
CThreadTransferDstScalarPerVector_W,
decltype(a_e_k_global_step_hacks),
decltype(b_e_n_ho_wo_global_step_hacks),
decltype(c_k_n_ho_wo_global_tensor_step_hacks),
decltype(a_e_k_global_move_slice_window_step_hack),
decltype(b_e_n_ho_wo_global_move_slice_window_step_hack)>;
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
index_t nrepeat = 100;
for(index_t i = 0; i < 5; ++i)
{
std::cout << "Start running " << nrepeat << " times..." << std::endl;
KernelTimer timer;
timer.Start();
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
<< std::endl;
for(index_t j = 0; j < nrepeat; ++j)
{
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, true>{});
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, true>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, true>{},
integral_constant<bool, false>{});
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, true>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, true>{});
}
else
{
const auto kernel =
run_gridwise_operation<gridwise_gemm,
decltype(wei_e_k_global_desc),
const FloatAB*,
decltype(in_e_n_ho_wo_global_desc),
const FloatAB*,
decltype(out_k_n_hop_wop_global_desc),
FloatC*,
integral_constant<bool, false>,
integral_constant<bool, false>>;
launch_kernel(kernel,
dim3(GridSize),
dim3(BlockSize),
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
p_in_global,
out_k_n_hop_wop_global_desc,
p_out_global,
integral_constant<bool, false>{},
integral_constant<bool, false>{});
}
}
timer.End();
float ave_time = timer.GetElapsedTime() / nrepeat;
float perf =
static_cast<float>(calculate_convolution_flops(in_n_c_hi_wi_global_desc,
wei_k_c_y_x_global_desc,
out_n_k0_ho_wo_k1_global_desc)) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
};
#endif

Some files were not shown because too many files have changed in this diff Show More