reorganize files to prepare for MIOpen integration (#51)

* change olc cmake

* adding online compile to fwd-v4r5r2

* update scripts

* remane fwd-v4r5r2 to fwd-v6r1

* clean up

[ROCm/composable_kernel commit: 1264925422]
This commit is contained in:
Chao Liu
2021-07-18 00:43:05 -05:00
committed by GitHub
parent 440e2126bb
commit f94e566273
110 changed files with 2992 additions and 4982 deletions

View File

@@ -6,14 +6,14 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
include(TargetFlags)
include(AddKernels)
#c++
## C++
enable_language(CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
#OpenMP
## OpenMP
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# workaround issue hipcc in rocm3.5 cannot find openmp
set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
@@ -35,56 +35,8 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
link_libraries(${OpenMP_gomp_LIBRARY})
link_libraries(${OpenMP_pthread_LIBRARY})
#GPU backend
if(DEVICE_BACKEND STREQUAL "AMD")
find_package(HIP REQUIRED)
endif()
#
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/half/include
${PROJECT_SOURCE_DIR}/driver/include
${PROJECT_BINARY_DIR}/composable_kernel/include/utility
)
if(DEVICE_BACKEND STREQUAL "AMD")
include_directories(BEFORE
${PROJECT_SOURCE_DIR}/external/rocm/include
)
endif()
if(DEVICE_BACKEND STREQUAL "AMD")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
endif()
add_subdirectory(driver)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")
if(DEVICE_BACKEND STREQUAL "AMD")
set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
endif()
add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})
target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)
target_link_libraries(conv_driver_v2 PRIVATE modConv)
target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
## HIP
find_package(HIP REQUIRED)
message(STATUS "Build with HIP ${hip_VERSION}")
add_subdirectory(host)

View File

@@ -1,292 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r1.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_GM11,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_GN11,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
{
throw std::runtime_error(
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
}
const auto a_gk_gm0_gm10_gm11_grid_desc =
GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);
const auto b_gk_gn0_gn10_gn11_grid_desc =
GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);
using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc);
using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
// c_blockid_to_gm10_gn10_block_cluster_adaptor
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K);
const bool has_double_tail_k_block_loop =
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K);
{
std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0)
<< ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", "
<< a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", "
<< a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0)
<< ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", "
<< b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", "
<< b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
else
{
const auto kernel = kernel_dynamic_contraction_v1r1<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGKGM0GM10GM11GridDesc>,
remove_reference_t<BGKGN0GN10GN11GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
}
return ave_time;
}
} // namespace ck
#endif

View File

@@ -13,19 +13,19 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
@@ -52,9 +52,9 @@ __host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
const FloatAB* p_b_grid,
FloatC* p_c_grid,
const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
@@ -71,79 +71,83 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
constexpr auto I5 = Number<5>{};
// GEMM
using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
if(!GridwiseContraction::CheckValidity(
a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
{
throw std::runtime_error(
"wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting");
throw std::runtime_error("wrong! "
"GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
"GM0_GM1_GN0_GN1 has invalid setting");
}
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc =
GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc);
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc =
GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc);
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc);
using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc);
using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);
// c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
// c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1);
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
// c_blockid_to_gm10_gn10_block_cluster_adaptor
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
// c_grid_block_cluster_blockid_to_gm10_gn10
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1);
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(c_grid_block_cluster_blockid_to_gm10_gn10);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);
const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);
const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);
@@ -151,41 +155,41 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);
{
std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{"
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", "
<< a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
<< a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{"
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", "
<< b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;
std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
<< b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;
std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
<< c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
}
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true,
true>;
@@ -198,21 +202,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
true,
false>;
@@ -225,21 +229,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false,
true>;
@@ -252,21 +256,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
else
{
const auto kernel = kernel_dynamic_contraction_v1r1<
const auto kernel = kernel_dynamic_contraction_v1r2<
GridwiseContraction,
FloatAB,
FloatC,
remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
false,
false>;
@@ -279,10 +283,10 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor);
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10);
}
return ave_time;

View File

@@ -1,387 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1
#define CK_DRIVER_DYNAMIC_GEMM_V1
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r1.hpp"
namespace ck {
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global,
const FloatAB* p_b_global,
FloatC* p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks,
index_t nrepeat)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
const auto K = a_k_m_global_desc.GetLength(I0);
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{};
if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0))
{
throw std::runtime_error("wrong! GEMM size no divisible");
}
// GEMM
using gridwise_gemm =
GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGlobalDesc,
BGlobalDesc,
CGlobalDesc,
CBlockClusterDesc,
MPerBlock,
NPerBlock,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGlobalIteratorHacks,
BGlobalIteratorHacks,
CGlobalIteratorHacks,
AGlobalMoveSliceWindowIteratorHacks,
BGlobalMoveSliceWindowIteratorHacks>;
const auto GridSize = (M / MPerBlock) * (N / NPerBlock);
const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc);
}
return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc));
DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc));
DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc));
DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc));
a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc);
b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc);
c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc);
c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc);
float ave_time = 0;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
true,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
true>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
else
{
const auto kernel = kernel_dynamic_gemm_v1r1<gridwise_gemm,
FloatAB,
FloatAB,
FloatC,
remove_reference_t<AGlobalDesc>,
remove_reference_t<BGlobalDesc>,
remove_reference_t<CGlobalDesc>,
remove_reference_t<CBlockClusterDesc>,
false,
false>;
ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(GridSize),
dim3(BlockSize),
0,
0,
p_a_global,
p_b_global,
p_c_global,
(void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(),
(void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer());
}
return ave_time;
#endif
}
} // namespace ck
#endif

View File

@@ -1,125 +0,0 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t N0_,
typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);
const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);
const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);
const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
// weight tensor
const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
make_tuple(make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(C * Y * X)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1, 2>{}, Sequence<0>{}));
// input tensor
const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hi_wi_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pass_through_transform(C),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
constexpr auto N0 = Number<N0_>{};
const auto N1 = N / N0;
const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
in_n_c_hip_wip_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_pass_through_transform(C),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
in_n0_n1_c_y_ho_x_wo_grid_desc,
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
make_pass_through_transform(N0),
make_merge_transform(make_tuple(N1, Ho, Wo))),
make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
// output tensor
const auto out_n_k_howo_grid_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));
const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
return make_tuple(
wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}
} // namespace ck
#endif

View File

@@ -0,0 +1,129 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
namespace ck {
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
typename... In,
typename... Out,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t GemmK1Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_y_x_c_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_hi_wi_c_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_ho_wo_k_grid_desc,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<GemmK1Value>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto GemmK1 = Number<GemmK1Value>{};
const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0);
const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3);
const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3);
const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1);
const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2);
const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1);
const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2);
const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1);
const auto X = wei_k_y_x_c_grid_desc.GetLength(I2);
const auto ConvStrideH = conv_strides[I0];
const auto ConvStrideW = conv_strides[I1];
const auto ConvDilationH = conv_dilations[I0];
const auto ConvDilationW = conv_dilations[I1];
const auto InLeftPadH = in_left_pads[I0];
const auto InLeftPadW = in_left_pads[I1];
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
const auto GemmM = K;
const auto GemmN = N * Ho * Wo;
const auto GemmK = C * Y * X;
const auto GemmK0 = GemmK / GemmK1;
// weight tensor
const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)),
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
wei_gemmk_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// input tensor
const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_gemmk_gemmn_grid_desc =
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor(
in_gemmk_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// output tensor
const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor(
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc);
}
} // namespace ck
#endif

View File

@@ -1,5 +1,5 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
@@ -17,10 +17,10 @@ template <typename... Wei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t N0Value,
index_t C0Value>
typename N0Type,
typename C0Type>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
@@ -28,8 +28,8 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
Number<N0Value>,
Number<C0Value>)
const N0Type& N0,
const C0Type& C0)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -61,9 +61,6 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const auto InRightPadH = in_right_pads[I0];
const auto InRightPadW = in_right_pads[I1];
constexpr auto N0 = Number<N0Value>{};
constexpr auto C0 = Number<C0Value>{};
const auto N1 = N / N0;
const auto C1 = C / C0;
@@ -109,7 +106,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
out_n_k_howo_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
make_unmerge_transform(make_tuple(I1, K)),
make_pass_through_transform(Ho * Wo)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
@@ -119,7 +116,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
out_n0_n1_1_k_howo_grid_desc,
make_tuple(make_pass_through_transform(I1),
make_pass_through_transform(K),
make_pass_through_transform(Number<N0>{}),
make_pass_through_transform(N0),
make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

View File

@@ -4,7 +4,7 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"
#include "threadwise_contraction.hpp"
namespace ck {

View File

@@ -4,43 +4,43 @@
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_gemm_v2.hpp"
#include "threadwise_contraction.hpp"
namespace ck {
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. AK0MK1BlockDesc is known at compile-time
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BK0NK1BlockDesc is known at compile-time
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
typename FloatA,
typename FloatB,
typename FloatC,
typename AK0MK1BlockDesc,
typename BK0NK1BlockDesc,
index_t M1PerThreadM11,
index_t N1PerThreadN11,
index_t KPerThread,
index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterM101,
index_t M1N1ThreadClusterN101,
index_t AThreadCopyScalarPerVector_M11,
index_t BThreadCopyScalarPerVector_N11,
typename std::enable_if<AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
typename ABlockDesc_BK0_BM_BK1,
typename BBlockDesc_BK0_BN_BK1,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false>
struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using AIndex = MultiIndex<3>;
using BIndex = MultiIndex<3>;
@@ -51,138 +51,144 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);
static constexpr index_t M100 = M1N1ThreadClusterM100;
static constexpr index_t N100 = M1N1ThreadClusterN100;
static constexpr index_t BM100 = BM10BN10ThreadClusterBM100;
static constexpr index_t BN100 = BM10BN10ThreadClusterBN100;
static constexpr index_t M101 = M1N1ThreadClusterM101;
static constexpr index_t N101 = M1N1ThreadClusterN101;
static constexpr index_t BM101 = BM10BN10ThreadClusterBM101;
static constexpr index_t BN101 = BM10BN10ThreadClusterBN101;
static constexpr index_t M11 = M1PerThreadM11;
static constexpr index_t N11 = N1PerThreadN11;
static constexpr index_t BM11 = BM1PerThreadBM11;
static constexpr index_t BN11 = BN1PerThreadBN11;
static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
static constexpr index_t BM1 =
BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11;
static constexpr index_t BN1 =
BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11;
static constexpr index_t M0 = M / M1;
static constexpr index_t N0 = N / N1;
static constexpr index_t BM0 = BM / BM1;
static constexpr index_t BN0 = BN / BN1;
__host__ __device__ static constexpr auto
MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc)
MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1)
{
const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
a_k0_m_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{})),
make_pass_through_transform(Number<K1>{})),
const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor(
a_block_desc_bk0_bm_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BM0>{}, Number<BM1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return a_k0_m0_m1_k1_block_desc;
return a_block_bk0_bm0_bm1_bk1;
}
__host__ __device__ static constexpr auto
MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc)
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1)
{
const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
b_k0_n_k1_block_desc,
make_tuple(make_pass_through_transform(Number<K0>{}),
make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{})),
make_pass_through_transform(Number<K1>{})),
const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor(
b_block_desc_bk0_bn_bk1,
make_tuple(make_pass_through_transform(Number<BK0>{}),
make_unmerge_transform(make_tuple(Number<BN0>{}, Number<BN1>{})),
make_pass_through_transform(Number<BK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
return b_k0_n0_n1_k1_block_desc;
return b_block_desc_bk0_bn0_bn1_bk1;
}
__host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
__host__ __device__ static constexpr auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n =
make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(
Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
Number<BM0>{}, Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_unmerge_transform(make_tuple(
Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
Number<BN0>{}, Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n;
}
__host__ __device__ static constexpr auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 =
make_single_stage_tensor_adaptor(
make_tuple(make_pass_through_transform(Number<M0>{}),
make_tuple(make_pass_through_transform(Number<BM0>{}),
make_unmerge_transform(
make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
make_pass_through_transform(Number<N0>{}),
make_tuple(Number<BM100>{}, Number<BM101>{}, Number<BM11>{})),
make_pass_through_transform(Number<BN0>{}),
make_unmerge_transform(
make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
make_tuple(Number<BN100>{}, Number<BN101>{}, Number<BN11>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1;
}
__host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
__host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
{
return Sequence<M0, M11, N0, N11>{};
return Sequence<BM0, BM11, BN0, BN11>{};
}
static constexpr auto a_k0_m0_m1_k1_block_desc_ =
MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{});
static constexpr auto b_k0_n0_n1_k1_block_desc_ =
MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{});
static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ =
MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{});
static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ =
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});
public:
__device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2()
: c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
__device__ BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
: c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id())},
a_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
b_thread_copy_{
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(BlockSize == M101 * M100 * N101 * N100,
static_assert(BlockSize == BM101 * BM100 * BN101 * BN100,
"wrong! blocksize and cluster size not consistent");
static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!");
static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) ==
BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
"wrong! K dimension not consistent");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2, "wrong");
static_assert(BM0 == 2 && BN0 == 2, "wrong");
}
__device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
__device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr auto adaptor0 =
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
make_pass_through_transform(M0),
make_pass_through_transform(M11),
make_pass_through_transform(N0),
make_pass_through_transform(N11)),
make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)),
make_pass_through_transform(BM0),
make_pass_through_transform(BM11),
make_pass_through_transform(BN0),
make_pass_through_transform(BN11)),
make_tuple(
Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
@@ -192,201 +198,203 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
}
template <typename CM0M1N0N1ThreadDesc,
template <typename CThreadDesc_BM0_BM11_BN0_BN11,
typename ABlockBuffer,
typename BBlockBuffer,
typename CThreadBuffer>
__device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
__device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc,
const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: remove this restriction
static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
static_assert(BM0 == 2 && BN0 == 2 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 &&
CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
"wrong");
auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize());
a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize());
b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
constexpr auto threadwise_gemm =
ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_k0_m0_m1_k1_thread_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
CM0M1N0N1ThreadDesc,
Sequence<KPerThread, K1>,
Sequence<1, M1PerThreadM11>,
Sequence<1, N1PerThreadN11>>{};
constexpr auto threadwise_contraction =
ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
CThreadDesc_BM0_BM11_BN0_BN11,
Sequence<BK0PerThread, BK1>,
Sequence<1, BM1PerThreadBM11>,
Sequence<1, BN1PerThreadBN11>>{};
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
// loop over rest of k
static_for<KPerThread, K0, KPerThread>{}([&](auto k) {
// loop over rest of bk0
static_for<BK0PerThread, BK0, BK0PerThread>{}([&](auto bk0) {
// read A_sub_0
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I0, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// read B_sub_0
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I0, I0, I0),
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I0, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
// read B_sub_1
b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_,
make_tuple(bk0, I1, I0, I0),
b_block_buf,
b_k0_n0_n1_k1_thread_desc_,
b_thread_desc_bk0_bn0_bn1_bk1_,
make_tuple(I0, I1, I0, I0),
b_thread_buf);
// read A_sub_1
a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_,
make_tuple(k, I1, I0, I0),
a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_,
make_tuple(bk0, I1, I0, I0),
a_block_buf,
a_k0_m0_m1_k1_thread_desc_,
a_thread_desc_bk0_bm0_bm1_bk1_,
make_tuple(I0, I1, I0, I0),
a_thread_buf);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
threadwise_contraction.Run(a_thread_buf,
make_tuple(I0, I1, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
private:
// A[K0, M0, M1, K1]
static constexpr auto a_k0_m0_m1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}, Number<K1>{}));
// A[BK0, BM0, BM1, BK1]
static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<BK0PerThread>{}, Number<BM0>{}, Number<BM1PerThreadBM11>{}, Number<BK1>{}));
// B[K0, N0, N1, K1]
static constexpr auto b_k0_n0_n1_k1_thread_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}, Number<K1>{}));
// B[BK0, BN0, BN1, BK1]
static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<BK0PerThread>{}, Number<BN0>{}, Number<BN1PerThreadBN11>{}, Number<BK1>{}));
using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatA,
FloatA,
decltype(a_k0_m0_m1_k1_block_desc_),
decltype(a_k0_m0_m1_k1_thread_desc_),
Sequence<KPerThread, 1, M1PerThreadM11, K1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
decltype(a_block_desc_bk0_bm0_bm1_bk1_),
decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
Sequence<BK0PerThread, 1, BM1PerThreadBM11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
FloatB,
FloatB,
decltype(b_k0_n0_n1_k1_block_desc_),
decltype(b_k0_n0_n1_k1_thread_desc_),
Sequence<KPerThread, 1, N1PerThreadN11, K1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
decltype(b_block_desc_bk0_bn0_bn1_bk1_),
decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
Sequence<0, 1, 2, 3>, // DimAccessOrder
Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths
Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
CIndex c_thread_origin_data_idx_;

View File

@@ -1,681 +0,0 @@
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGKGM0GM10GM11GridDesc,
typename BGKGN0GN10GN11GridDesc,
typename CGM10BM0BM1GN10BN0BN1GridDesc,
typename CBlockIdToGM10GN10BlockClusterAdaptor,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_contraction_v1r1(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGKGM0GM10GM11GridDesc a_gk_gm0_gm10_gm11_grid_desc,
const BGKGN0GN10GN11GridDesc b_gk_gn0_gn10_gn11_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor
c_blockid_to_gm10_gn10_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGKGM0GM1GridDesc,
typename BGKGN0GN1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_GM11,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_GN11,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGridIteratorHacks,
typename BGridIteratorHacks,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to known at compile-time
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
Number<BBlockTransferDstScalarPerVector_GN11>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc,
const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
GM0 == a_gk_gm0_gm1_grid_desc.GetLength(I1) &&
GM1 == a_gk_gm0_gm1_grid_desc.GetLength(I2) &&
GN0 == b_gk_gn0_gn1_grid_desc.GetLength(I1) &&
GN1 == b_gk_gn0_gn1_grid_desc.GetLength(I2) &&
GK == b_gk_gn0_gn1_grid_desc.GetLength(I0)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK % KPerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
const index_t GM10 = GM1 / GM11;
const index_t GN10 = GN1 / GN11;
const index_t grid_size = GM10 * GN10;
return grid_size;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK)
{
const bool has_main_k_block_loop = (GK + KPerBlock) / (2 * KPerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK)
{
const bool has_double_tail_k_block_loop = (GK / KPerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGKGM0GM10GM11GridDescriptor(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc)
{
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_gk_gm0_gm10_gm11_grid_desc = transform_dynamic_tensor_descriptor(
a_gk_gm0_gm1_grid_desc,
make_tuple(make_pass_through_transform(GK),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
return a_gk_gm0_gm10_gm11_grid_desc;
}
__host__ __device__ static constexpr auto
MakeBGKGN0GN10GN11GridDescriptor(const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc)
{
const auto GK = b_gk_gn0_gn1_grid_desc.GetLength(I0);
const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_gk_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
b_gk_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GK),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
return b_gk_gn0_gn10_gn11_grid_desc;
}
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
constexpr auto BM = GM0 * GM11;
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
constexpr auto BN1 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm1_gn0_gn1_grid_desc,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_merge_transform(make_tuple(GM0, GM11)),
make_pass_through_transform(GN10),
make_merge_transform(make_tuple(GN0, GN11))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
make_pass_through_transform(GN10),
make_unmerge_transform(make_tuple(BN0, BN1))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
}
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
}
using AGKGM0GM10GM11GridDesc = decltype(MakeAGKGM0GM10GM11GridDescriptor(AGKGM0GM1GridDesc{}));
using BGKGN0GN10GN11GridDesc = decltype(MakeBGKGN0GN10GN11GridDescriptor(BGKGN0GN1GridDesc{}));
using CGM10BM0BM1GN10BN0BN1GridDesc =
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGKGM0GM10GM11GridDesc& a_gk_gm0_gm10_gm11_grid_desc,
const BGKGN0GN10GN11GridDesc& b_gk_gn0_gn10_gn11_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_gk_gm0_gm10_gm11_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_gk_gn0_gn10_gn11_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
const auto GK = a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);
// lds max alignment
// part of them should be moved into blockwise-gemm
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_GM11>{},
Number<BBlockTransferDstScalarPerVector_GN11>{},
Number<M1PerThreadM111>{},
Number<N1PerThreadN111>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk_gm0_gm10_gm11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk_gn0_gn10_gn11_block_desc =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}), max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_gk_bm_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_gk_bn_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11>,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_gk_gm0_gm10_gm11_grid_desc),
decltype(a_gk_gm0_gm10_gm11_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(a_gk_gm0_gm10_gm11_grid_desc,
make_multi_index(0, 0, igm10, 0),
a_gk_gm0_gm10_gm11_block_desc,
make_multi_index(0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11>,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_gk_gn0_gn10_gn11_grid_desc),
decltype(b_gk_gn0_gn10_gn11_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_gk_gn0_gn10_gn11_grid_desc,
make_multi_index(0, 0, ign10, 0),
b_gk_gn0_gn10_gn11_block_desc,
make_multi_index(0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_gk_bm_block_desc),
decltype(b_gk_bn_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{};
constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack =
AGridMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(
a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(
b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < GK - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_gk_gm0_gm10_gm11_grid_desc,
a_block_slice_copy_step,
a_k_m0_m1_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_gk_gn0_gn10_gn11_grid_desc,
b_block_slice_copy_step,
b_k_n0_n1_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = GM1PerBlockGM11 / M11;
constexpr index_t N10 = GN1PerBlockGN11 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
Sequence<1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
make_multi_index(igm10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
ign10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_grid_buf,
CGridIteratorHacks{});
}
}
};
} // namespace ck
#endif

View File

@@ -5,8 +5,8 @@
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
@@ -15,10 +15,10 @@ namespace ck {
template <typename GridwiseContraction,
typename FloatAB,
typename FloatC,
typename AGK0GM0GM10GM11GK1GridDesc,
typename BGK0GN0GN10GN11GK1GridDesc,
typename CGM10BM0BM1GN10BN0BN1GridDesc,
typename CBlockIdToGM10GN10BlockClusterAdaptor,
typename AGridDesc_GK0_GM0_GM10_GM11_GK1,
typename BGridDesc_GK0_GN0_GN10_GN11_GK1,
typename CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1,
typename CGridBlockCluster_BlockId_To_GM10_GN10,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
@@ -29,11 +29,10 @@ __global__ void
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc,
const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor
c_blockid_to_gm10_gn10_block_cluster_adaptor)
const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10)
{
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -44,10 +43,10 @@ __global__ void
p_b_grid,
p_c_grid,
p_shared_block,
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
@@ -57,19 +56,19 @@ template <index_t BlockSize,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGK0GM0GM1GK1GridDesc,
typename BGK0GN0GN1GK1GridDesc,
typename CGM0GM1GN0GN1GridDesc,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
index_t GM1PerBlockGM11,
index_t GN1PerBlockGN11,
index_t KPerBlock,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
index_t GK0PerBlock,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
@@ -92,7 +91,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -100,9 +99,9 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
static constexpr auto I3 = Number<3>{};
// GM0 and GN0 need to known at compile-time
static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3);
static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0);
static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2);
static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3);
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -113,61 +112,62 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
}
__host__ __device__ static constexpr bool
CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
"wrong! GM0 and GN0 need to be known at compile-time");
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) &&
GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) &&
GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) &&
GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) &&
GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) &&
GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % KPerBlock == 0));
return (
(GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) &&
GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) &&
GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) &&
GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) &&
GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) &&
GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) &&
GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) &&
GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) &&
GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) &&
GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) &&
(GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0));
}
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr index_t GM11 = GM1PerBlockGM11;
constexpr index_t GN11 = GN1PerBlockGN11;
@@ -182,29 +182,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
{
const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1;
const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1;
return has_main_k_block_loop;
}
__host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
{
const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0;
const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0;
return has_double_tail_k_block_loop;
}
__host__ __device__ static constexpr auto
MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc)
__host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1)
{
const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2);
const auto GM11 = Number<GM1PerBlockGM11>{};
const auto GM10 = GM1 / GM11;
const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
a_gk0_gm0_gm1_gk1_grid_desc,
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor(
a_grid_desc_gk0_gm0_gm1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
@@ -212,20 +212,20 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return a_gk0_gm0_gm10_gm11_gk1_grid_desc;
return a_grid_desc_gk0_gm0_gm10_gm11_gk1;
}
__host__ __device__ static constexpr auto
MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc)
__host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1)
{
const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0);
const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0);
const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2);
const auto GN11 = Number<GN1PerBlockGN11>{};
const auto GN10 = GN1 / GN11;
const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
b_gk0_gn0_gn1_gk1_grid_desc,
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor(
b_grid_desc_gk0_gn0_gn1_gk1,
make_tuple(make_pass_through_transform(GK0),
make_pass_through_transform(GN0),
make_unmerge_transform(make_tuple(GN10, GN11)),
@@ -233,14 +233,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));
return b_gk0_gn0_gn10_gn11_gk1_grid_desc;
return b_grid_desc_gk0_gn0_gn10_gn11_gk1;
}
__host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
__host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
@@ -252,15 +252,15 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
constexpr auto BN = GN0 * GN11;
constexpr auto BM1 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
Number<BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11>{};
constexpr auto BN1 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
Number<BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11>{};
constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
c_gm0_gm1_gn0_gn1_grid_desc,
c_grid_desc_gm0_gm1_gn0_gn1,
make_tuple(make_pass_through_transform(GM0),
make_unmerge_transform(make_tuple(GM10, GM11)),
make_pass_through_transform(GN0),
@@ -277,7 +277,7 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor(
c_gm10_bm_gn10_bn_grid_desc,
make_tuple(make_pass_through_transform(GM10),
make_unmerge_transform(make_tuple(BM0, BM1)),
@@ -286,14 +286,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));
return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
}
__host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
__host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10(
const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1)
{
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
constexpr auto GM11 = Number<GM1PerBlockGM11>{};
constexpr auto GN11 = Number<GN1PerBlockGN11>{};
@@ -301,22 +301,22 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
const auto GM10 = GM1 / GM11;
const auto GN10 = GN1 / GN11;
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
make_tuple(Sequence<0, 1>{}),
make_tuple(Sequence<0>{}));
return c_blockid_to_gm10_gn10_block_cluster_adaptor;
return c_grid_block_cluster_blockid_to_gm10_gn10;
}
using AGK0GM0GM10GM11GK1GridDesc =
decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{}));
using BGK0GN0GN10GN11GK1GridDesc =
decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{}));
using CGM10BM0BM1GN10BN0BN1GridDesc =
decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{}));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{}));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{}));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{}));
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
@@ -324,25 +324,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc,
const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc,
const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize());
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize());
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());
const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0);
const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
// divide block work by [GM10, GN10]
const auto c_gm10_gn10_block_cluster_idx =
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex(
make_multi_index(get_block_1d_id()));
// HACK: this force index data into SGPR
@@ -356,46 +356,46 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 =
make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
make_tuple(Number<GK0PerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
max_lds_align);
// A matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);
// B matrix in LDS memory for blockwise GEMM
// be careful of LDS alignment
constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
constexpr auto b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<GK0PerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);
static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() ==
a_gk0_bm_gk1_block_desc.GetElementSpaceSize() &&
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() ==
b_gk0_bn_gk1_block_desc.GetElementSpaceSize(),
static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() ==
a_block_desc_gk0_bm_gk1.GetElementSpaceSize() &&
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() ==
b_block_desc_gk0_bn_gk1.GetElementSpaceSize(),
"wrong!");
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc),
decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc),
decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1),
decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
@@ -403,23 +403,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, igm10, 0, 0),
a_gk0_gm0_gm10_gm11_gk1_block_desc,
a_block_desc_gk0_gm0_gm10_gm11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc),
decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc),
decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1),
decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3, 4>,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
@@ -427,102 +427,103 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder
false,
true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, ign10, 0, 0),
b_gk0_gn0_gn10_gn11_gk1_block_desc,
b_block_desc_gk0_gn0_gn10_gn11_gk1,
make_multi_index(0, 0, 0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
// a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS
// b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_gk0_bm_gk1_block_desc),
decltype(b_gk0_bn_gk1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_block_desc_gk0_bm_gk1),
decltype(b_block_desc_gk0_bn_gk1),
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
BM1PerThreadBM11,
BN1PerThreadBN11>{};
constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 =
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2(
sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));
sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1));
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
decltype(c_bm0_bm1_bn0_bn1_thread_desc),
decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
.Run(c_bm0_bm1_bn0_bn1_thread_desc,
decltype(c_thread_desc_bm0_bm1_bn0_bn1),
decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{}
.Run(c_thread_desc_bm0_bm1_bn0_bn1,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
index_t gk0_block_on_grid = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
@@ -530,25 +531,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1,
a_block_even_buf,
b_block_even_buf,
c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
@@ -556,29 +557,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < GK0 - 2 * KPerBlock);
gk0_block_on_grid += 2 * GK0PerBlock;
} while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1,
a_block_slice_copy_step,
AGridMoveSliceWindowIteratorHacks{});
b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1,
b_block_slice_copy_step,
BGridMoveSliceWindowIteratorHacks{});
@@ -586,23 +587,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(
a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{});
b_blockwise_copy.RunRead(
b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});
b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{});
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);
a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
@@ -610,61 +611,51 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr index_t M10 = GM1PerBlockGM11 / M11;
constexpr index_t N10 = GN1PerBlockGN11 / N11;
constexpr index_t M111 = M1PerThreadM111;
constexpr index_t N111 = N1PerThreadN111;
constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1]>{},
I1,
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2]>{},
Number<c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>{}));
const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 =
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1),
decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1),
Sequence<1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1],
1,
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2],
c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
make_multi_index(igm10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
ign10,
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
.Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_multi_index(igm10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1],
ign10,
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2],
c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])}
.Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1,
make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_buf,
CGridIteratorHacks{});
}

View File

@@ -1,552 +0,0 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
namespace ck {
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc a_k_m_global_desc,
const BGlobalDesc b_k_n_global_desc,
const CGlobalDesc c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc c_block_cluster_desc)
{
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename FloatA,
typename FloatB,
typename FloatC,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global,
const FloatB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const void __CONSTANT__* p_a_k_m_global_desc,
const void __CONSTANT__* p_b_k_n_global_desc,
const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
const void __CONSTANT__* p_c_block_cluster_desc)
{
// first cast void __CONSTANT__ void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m_global_desc =
*reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
const auto b_k_n_global_desc =
*reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
const auto c_m0_m1_n0_n1_global_desc =
*reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
const auto c_block_cluster_desc =
*reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
GridwiseGemm::Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
typename CBlockClusterDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t M1PerThread,
index_t N1PerThread,
index_t KPerThread,
index_t M1N1ThreadClusterM10,
index_t M1N1ThreadClusterN10,
index_t M1N1ThreadClusterM11,
index_t M1N1ThreadClusterN11,
typename ABlockTransferThreadSliceLengths_K_M,
typename ABlockTransferThreadClusterLengths_K_M,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_M,
bool AThreadTransferSrcResetCoordinateAfterRun,
typename BBlockTransferThreadSliceLengths_K_N,
typename BBlockTransferThreadClusterLengths_K_N,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_N,
bool BThreadTransferSrcResetCoordinateAfterRun,
typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector,
typename AGlobalIteratorHacks,
typename BGlobalIteratorHacks,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<M1PerThread>{},
Number<N1PerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
FloatAB* __restrict__ p_shared_block,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_a_global, a_k_m_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_b_global, b_k_n_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
const auto K = a_k_m_global_desc.GetLength(I0);
const auto M = a_k_m_global_desc.GetLength(I1);
const auto N = b_k_n_global_desc.GetLength(I1);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
// HACK: this force m/n_block_data_idx_on_global into SGPR
const index_t m_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_global =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// lds max alignment
constexpr auto max_lds_align = math::lcm(Number<ABlockTransferDstScalarPerVector_M>{},
Number<BBlockTransferDstScalarPerVector_N>{},
Number<M1PerThread>{},
Number<N1PerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), max_lds_align);
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), max_lds_align);
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, MPerBlock>,
ABlockTransferThreadSliceLengths_K_M,
ABlockTransferThreadClusterLengths_K_M,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
ABlockTransferSrcAccessOrder,
Sequence<0, 1>,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_M,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true>(
a_k_m_global_desc,
make_multi_index(0, m_block_data_idx_on_global),
a_k_m_block_desc,
make_multi_index(0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
Sequence<KPerBlock, NPerBlock>,
BBlockTransferThreadSliceLengths_K_N,
BBlockTransferThreadClusterLengths_K_N,
BBlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
BBlockTransferSrcAccessOrder,
Sequence<0, 1>,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_N,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true>(
b_k_n_global_desc,
make_multi_index(0, n_block_data_idx_on_global),
b_k_n_block_desc,
make_multi_index(0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlocl, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
static_assert(
MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
"wrong!");
constexpr index_t M0PerThread =
MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
constexpr index_t N0PerThread =
NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);
constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
a_k_m_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<M0PerThread>{},
Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
b_k_n_block_desc,
make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
make_unmerge_transform(make_tuple(
Number<N0PerThread>{},
Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
constexpr auto c_m0_m1_n0_n1_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<M0PerThread>{},
Number<M1PerThread>{},
Number<N0PerThread>{},
Number<N1PerThread>{}));
const auto blockwise_gemm =
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m0_m1_block_desc),
decltype(b_k_n0_n1_block_desc),
decltype(c_m0_m1_n0_n1_thread_desc),
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
M1PerThread,
N1PerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
ThreadwiseDynamicTensorSliceSet_v1<
FloatAcc,
decltype(c_m0_m1_n0_n1_thread_desc),
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{};
constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
constexpr auto a_k_m_global_move_slice_window_iterator_hack =
AGlobalMoveSliceWindowIteratorHacks{};
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
BGlobalMoveSliceWindowIteratorHacks{};
auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double, a_k_m_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double, b_k_n_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize());
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
}
if constexpr(HasMainKBlockLoop)
{
index_t k_block_data_begin = 0;
// LDS double buffer: main body
// use Do-While loop instead of For loop to simplify control flow
do
{
// even iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS doubel buffer: load next data from device mem
a_blockwise_copy.RunRead(
a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(
b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf);
k_block_data_begin += 2 * KPerBlock;
} while(k_block_data_begin < K - 2 * KPerBlock);
}
// LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{
a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc,
a_block_slice_copy_step,
a_k_m_global_move_slice_window_iterator_hack);
b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc,
b_block_slice_copy_step,
b_k_n_global_move_slice_window_iterator_hack);
__syncthreads();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks);
b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
}
// output: register to global memory
{
constexpr auto M1 = Number<M1PerThread * M1N1ThreadClusterM10 * M1N1ThreadClusterM11>{};
constexpr auto N1 = Number<N1PerThread * M1N1ThreadClusterN10 * M1N1ThreadClusterN11>{};
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const auto c_thread_data_idx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
FloatC,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>{
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
c_thread_data_idx_on_block[I1],
n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
c_thread_data_idx_on_block[I3])}
.Run(c_m0_m1_n0_n1_thread_desc,
make_tuple(I0, I0, I0, I0),
c_thread_buf,
c_m0_m1_n0_n1_global_desc,
c_global_buf,
c_m0_m1_n0_n1_global_tensor_iterator_hacks);
}
}
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global,
FloatC* __restrict__ p_c_global,
const AGlobalDesc& a_k_m_global_desc,
const BGlobalDesc& b_k_n_global_desc,
const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
const CBlockClusterDesc& c_block_cluster_desc,
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
Run(p_a_global,
p_b_global,
p_c_global,
a_k_m_global_desc,
b_k_n_global_desc,
c_m0_m1_n0_n1_global_desc,
c_block_cluster_desc,
p_shared_block,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
};
} // namespace ck
#endif

View File

@@ -435,21 +435,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

View File

@@ -1,40 +1,44 @@
#ifndef CK_THREADWISE_GEMM_V2_HPP
#define CK_THREADWISE_GEMM_V2_HPP
#ifndef CK_THREADWISE_CONTRACTION_HPP
#define CK_THREADWISE_CONTRACTION_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction
static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
@@ -70,28 +74,31 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto K = KLengths{}[I0];
constexpr auto M0 = MLengths{}[I0];
constexpr auto M1 = MLengths{}[I1];
constexpr auto N0 = NLengths{}[I0];
constexpr auto N1 = NLengths{}[I1];
constexpr auto TK = TKLengths{}[I0];
constexpr auto TM0 = TMLengths{}[I0];
constexpr auto TM1 = TMLengths{}[I1];
constexpr auto TN0 = TNLengths{}[I0];
constexpr auto TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
static_for<0, TK, 1>{}([&](auto tk) {
static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, TN1, 1>{}([&](auto tn1) {
constexpr index_t a_offset =
ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1));
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk, tm0, tm1));
constexpr index_t b_offset =
BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1));
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk, tn0, tn1));
constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<FloatA, FloatB, FloatC>(
a_buf[Number<a_offset>{}],
@@ -105,35 +112,39 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
}
};
// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1]
// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
// known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
typename FloatB,
typename FloatC,
typename ADesc,
typename BDesc,
typename CDesc,
typename KLengths,
typename MLengths,
typename NLengths,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
__device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1()
__device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
// TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths
// TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
// CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths
// TODO remove this restriction
static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2,
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
@@ -169,43 +180,45 @@ struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t K0 = KLengths{}[I0];
constexpr index_t K1 = KLengths{}[I1];
constexpr index_t M0 = MLengths{}[I0];
constexpr index_t M1 = MLengths{}[I1];
constexpr index_t N0 = NLengths{}[I0];
constexpr index_t N1 = NLengths{}[I1];
constexpr index_t TK0 = TKLengths{}[I0];
constexpr index_t TK1 = TKLengths{}[I1];
constexpr index_t TM0 = TMLengths{}[I0];
constexpr index_t TM1 = TMLengths{}[I1];
constexpr index_t TN0 = TNLengths{}[I0];
constexpr index_t TN1 = TNLengths{}[I1];
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, K0, 1>{}([&](auto k0) {
static_for<0, M0, 1>{}([&](auto m0) {
static_for<0, M1, 1>{}([&](auto m1) {
static_for<0, N0, 1>{}([&](auto n0) {
static_for<0, N1, 1>{}([&](auto n1) {
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM0, 1>{}([&](auto tm0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN0, 1>{}([&](auto tn0) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, K1> a_vec;
vector_type<FloatB, K1> b_vec;
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, K1, 1>{}([&](auto k1) {
constexpr index_t a_offset = ADesc{}.CalculateOffset(
a_origin_idx + make_multi_index(k0, m0, m1, k1));
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t a_offset =
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
constexpr index_t b_offset = BDesc{}.CalculateOffset(
b_origin_idx + make_multi_index(k0, n0, n1, k1));
constexpr index_t b_offset =
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, K1>::type;
using b_vector_t = typename vector_type<FloatB, K1>::type;
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CDesc{}.CalculateOffset(
c_origin_idx + make_multi_index(m0, m1, n0, n1));
constexpr index_t c_offset =
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
a_vec.template AsType<a_vector_t>()[I0],

View File

@@ -1,379 +0,0 @@
#include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r1.hpp"
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr index_t N0 = CK_PARAM_N0;
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t KPerBlock = CK_PARAM_KPerBlock;
constexpr index_t M1PerThread = CK_PARAM_M1PerThread;
constexpr index_t N1PerThread = CK_PARAM_N1PerThread;
constexpr index_t KPerThread = CK_PARAM_KPerThread;
constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10;
constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10;
constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11;
constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11;
using ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11>;
using ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11>;
using ABlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_ABlockTransferThreadClusterArrangeOrder>;
using ABlockTransferSrcAccessOrder = Sequence<CK_PARAM_ABlockTransferSrcAccessOrder>;
constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim;
constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector;
constexpr index_t ABlockTransferDstScalarPerVector_GM11 =
CK_PARAM_ABlockTransferDstScalarPerVector_GM11;
constexpr bool AThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun);
using BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11>;
using BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11>;
using BBlockTransferThreadClusterArrangeOrder =
Sequence<CK_PARAM_BBlockTransferThreadClusterArrangeOrder>;
using BBlockTransferSrcAccessOrder = Sequence<CK_PARAM_BBlockTransferSrcAccessOrder>;
constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim;
constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector;
constexpr index_t BBlockTransferDstScalarPerVector_GN11 =
CK_PARAM_BBlockTransferDstScalarPerVector_GN11;
constexpr bool BThreadTransferSrcResetCoordinateAfterRun =
static_cast<bool>(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun);
using CThreadTransferSrcDstAccessOrder = Sequence<CK_PARAM_CThreadTransferSrcDstAccessOrder>;
constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_prepare(
int n,
int c,
int hi,
int wi,
int k,
int y,
int x,
int convStrideH,
int convStrideW,
int convDilationY,
int convDilationX,
int leftPadH,
int leftPadW,
int rightPadH,
int rightPadW,
void* p_a_gk_gm0_gm10_gm11_grid_desc,
void* p_b_gk_gn0_gn10_gn11_grid_desc,
void* p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
void* p_c_blockid_to_gm10_gn10_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1;
const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1;
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi));
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo));
const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(convStrideH, convStrideW),
make_tuple(convDilationY, convDilationX),
make_tuple(leftPadH, leftPadW),
make_tuple(rightPadH, rightPadW));
const auto a_gk_gm0_gm1_grid_desc = descs[I0];
const auto b_gk_gn0_gn1_grid_desc = descs[I1];
const auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
using AGKGM0GM1GridDesc = decltype(a_gk_gm0_gm1_grid_desc);
using BGKGN0GN1GridDesc = decltype(b_gk_gn0_gn1_grid_desc);
using CGM0GM1GN0GN1GridDesc = decltype(c_gm0_gm1_gn0_gn1_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{})));
using BGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>;
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set, /* ToDo tunable */
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
auto a_gk_gm0_gm10_gm11_grid_desc =
GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc);
auto b_gk_gn0_gn10_gn11_grid_desc =
GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc);
auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);
auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);
if(hipThreadIdx_x == 0)
{
*static_cast<decltype(a_gk_gm0_gm10_gm11_grid_desc)*>(p_a_gk_gm0_gm10_gm11_grid_desc) =
a_gk_gm0_gm10_gm11_grid_desc;
*static_cast<decltype(b_gk_gn0_gn10_gn11_grid_desc)*>(p_b_gk_gn0_gn10_gn11_grid_desc) =
b_gk_gn0_gn10_gn11_grid_desc;
*static_cast<decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc)*>(
p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc) = c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
*static_cast<decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor)*>(
p_c_blockid_to_gm10_gn10_block_cluster_adaptor) =
c_blockid_to_gm10_gn10_block_cluster_adaptor;
};
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_gk_gm0_gm10_gm11_grid_desc,
const void __CONSTANT__* p_b_gk_gn0_gn10_gn11_grid_desc,
const void __CONSTANT__* p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
const void __CONSTANT__* p_c_blockid_to_gm10_gn10_block_cluster_adaptor)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1));
constexpr auto a_gk_gm0_gm1_grid_desc = descs[I0];
constexpr auto b_gk_gn0_gn1_grid_desc = descs[I1];
constexpr auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
using AGKGM0GM1GridDesc = decltype(a_gk_gm0_gm1_grid_desc);
using BGKGN0GN1GridDesc = decltype(b_gk_gn0_gn1_grid_desc);
using CGM0GM1GN0GN1GridDesc = decltype(c_gm0_gm1_gn0_gn1_grid_desc);
using AGridIteratorHacks = decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{})));
using BGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})));
using CGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})));
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>;
using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set, /* ToDo tunable */
AGKGM0GM1GridDesc,
BGKGN0GN1GridDesc,
CGM0GM1GN0GN1GridDesc,
GM1PerBlockGM11,
GN1PerBlockGN11,
KPerBlock,
M1PerThread,
N1PerThread,
KPerThread,
M1N1ThreadClusterM10,
M1N1ThreadClusterN10,
M1N1ThreadClusterM11,
M1N1ThreadClusterN11,
ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_GM11,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_GN11,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
using AGKGM0GM10GM11GridDesc =
decltype(GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc));
using BGKGN0GN10GN11GridDesc =
decltype(GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc));
using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(
GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc));
using CBlockIdToGM10GN10BlockClusterAdaptor =
decltype(GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(
c_gm0_gm1_gn0_gn1_grid_desc));
const auto a_gk_gm0_gm10_gm11_grid_desc = *reinterpret_cast<const AGKGM0GM10GM11GridDesc*>(
(const void*)p_a_gk_gm0_gm10_gm11_grid_desc);
const auto b_gk_gn0_gn10_gn11_grid_desc = *reinterpret_cast<const BGKGN0GN10GN11GridDesc*>(
(const void*)p_b_gk_gn0_gn10_gn11_grid_desc);
const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
*reinterpret_cast<const CGM10BM0BM1GN10BN0BN1GridDesc*>(
(const void*)p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);
const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToGM10GN10BlockClusterAdaptor*>(
(const void*)p_c_blockid_to_gm10_gn10_block_cluster_adaptor);
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_gk_gm0_gm10_gm11_grid_desc,
b_gk_gn0_gn10_gn11_grid_desc,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
c_blockid_to_gm10_gn10_block_cluster_adaptor,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};

View File

@@ -0,0 +1,402 @@
#include "common_header.hpp"
#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
using namespace ck;
using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_ACC_DATATYPE)>::type;
using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BlockSize;
constexpr auto GN0 = Number<CK_PARAM_GN0>{};
constexpr auto GK1 = Number<CK_PARAM_GK1>{};
constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100;
constexpr index_t BM10BN10ThreadClusterBN100 = CK_PARAM_BM10BN10ThreadClusterBN100;
constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101;
constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>;
using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 =
Sequence<CK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1>;
using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>;
using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 =
Sequence<CK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1>;
using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>;
using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>;
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare(
index_t N,
index_t C,
index_t Hi,
index_t Wi,
index_t K,
index_t Y,
index_t X,
index_t ConvStrideH,
index_t ConvStrideW,
index_t ConvDilationH,
index_t ConvDilationW,
index_t InLeftPadH,
index_t InLeftPadW,
index_t InRightPadH,
index_t InRightPadW,
void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1,
void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
void* p_c_grid_block_cluster_blockid_to_gm10_gn10)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
const index_t Ho =
(Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1;
const index_t Wo =
(Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1;
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi));
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X));
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(ConvStrideH, ConvStrideW),
make_tuple(ConvDilationH, ConvDilationW),
make_tuple(InLeftPadH, InLeftPadW),
make_tuple(InRightPadH, InRightPadW),
GN0,
GK1);
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridIteratorHacks = decltype(make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1);
auto c_grid_block_cluster_blockid_to_gm10_gn10 =
GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1);
if(hipThreadIdx_x == 0)
{
*static_cast<decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1)*>(
p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1;
*static_cast<decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1)*>(
p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1;
*static_cast<decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1)*>(
p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
*static_cast<decltype(c_grid_block_cluster_blockid_to_gm10_gn10)*>(
p_c_grid_block_cluster_blockid_to_gm10_gn10) =
c_grid_block_cluster_blockid_to_gm10_gn10;
};
};
extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1,
const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
constexpr auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3));
constexpr auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
constexpr auto descs =
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
make_tuple(1, 1),
GN0,
GK1);
constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1);
using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1);
using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1);
using AGridIteratorHacks =
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11
Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using BGridIteratorHacks = decltype(make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1
using CGridIteratorHacks = decltype(make_tuple(
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
make_tuple(
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1
using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>;
using BGridMoveSliceWindowIteratorHacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;
using GridwiseContraction =
GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
BlockSize,
FloatAB,
FloatAcc,
FloatC,
InMemoryDataOperation::Set,
AGridDesc_GK0_GM0_GM1_GK1,
BGridDesc_GK0_GN0_GN1_GK1,
CGridDesc_GM0_GM1_GN0_GN1,
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorContiguousDimOrder,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorContiguousDimOrder,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridIteratorHacks,
BGridIteratorHacks,
CGridIteratorHacks,
AGridMoveSliceWindowIteratorHacks,
BGridMoveSliceWindowIteratorHacks>;
using AGridDesc_GK0_GM0_GM10_GM11_GK1 =
decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
a_grid_desc_gk0_gm0_gm1_gk1));
using BGridDesc_GK0_GN0_GN10_GN11_GK1 =
decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
b_grid_desc_gk0_gn0_gn1_gk1));
using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 =
decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
c_grid_desc_gm0_gm1_gn0_gn1));
using CGridBlockCluster_BlockId_To_GM10_GN10 =
decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
c_grid_desc_gm0_gm1_gn0_gn1));
const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
*reinterpret_cast<const AGridDesc_GK0_GM0_GM10_GM11_GK1*>(
(const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1);
const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
*reinterpret_cast<const BGridDesc_GK0_GN0_GN10_GN11_GK1*>(
(const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1);
const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
*reinterpret_cast<const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1*>(
(const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
*reinterpret_cast<const CGridBlockCluster_BlockId_To_GM10_GN10*>(
(const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10);
constexpr index_t shared_block_size =
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
__shared__ FloatAB p_shared_block[shared_block_size];
GridwiseContraction::Run(p_a_grid,
p_b_grid,
p_c_grid,
p_shared_block,
a_grid_desc_gk0_gm0_gm10_gm11_gk1,
b_grid_desc_gk0_gn0_gn10_gn11_gk1,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
c_grid_block_cluster_blockid_to_gm10_gn10,
integral_constant<bool, HasMainKBlockLoop>{},
integral_constant<bool, HasDoubleTailKBlockLoop>{});
};

View File

@@ -1,271 +0,0 @@
#ifndef CONV_TUNABLES_HPP
#define CONV_TUNABLES_HPP
#include "config.hpp"
struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw
{
ck::index_t BlockSize; // usually not tunable
ck::index_t MPerBlock;
ck::index_t NPerBlock;
ck::index_t KPerBlock;
ck::index_t M1PerThread;
ck::index_t N1PerThread;
ck::index_t KPerThread;
ck::index_t M1N1ThreadClusterM10;
ck::index_t M1N1ThreadClusterN10;
ck::index_t M1N1ThreadClusterM11;
ck::index_t M1N1ThreadClusterN11;
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
ck::index_t ABlockTransferSrcVectorDim;
ck::index_t ABlockTransferSrcScalarPerVector;
ck::index_t ABlockTransferDstScalarPerVector_M1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
ck::index_t BBlockTransferSrcVectorDim;
ck::index_t BBlockTransferSrcScalarPerVector;
ck::index_t BBlockTransferDstScalarPerVector_N1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 6> CThreadTransferSrcDstAccessOrder;
ck::index_t CThreadTransferSrcDstVectorDim;
ck::index_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = {
256, 128, 128, 8, 4, 4, 1,
8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0},
{2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128},
{0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
5, 1};
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
ck::index_t BlockSize; // usually not tunable
ck::index_t MPerBlock;
ck::index_t NPerBlock;
ck::index_t KPerBlock;
ck::index_t MPerWave;
ck::index_t NPerWave;
ck::index_t K1;
ck::index_t MRepeat;
ck::index_t NRepeat;
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
ck::index_t ABlockTransferSrcVectorDim;
ck::index_t ABlockTransferSrcScalarPerVector;
ck::index_t ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
ck::index_t BBlockTransferSrcVectorDim;
ck::index_t BBlockTransferSrcScalarPerVector;
ck::index_t BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
ck::index_t CThreadTransferSrcDstVectorDim;
ck::index_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
ck::index_t BlockSize; // usually not tunable
ck::index_t MPerBlock;
ck::index_t NPerBlock;
ck::index_t KPerBlock;
ck::index_t MPerWave;
ck::index_t NPerWave;
ck::index_t K1;
ck::index_t MRepeat;
ck::index_t NRepeat;
std::array<ck::index_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<ck::index_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> ABlockTransferSrcAccessOrder;
ck::index_t ABlockTransferSrcVectorDim;
ck::index_t ABlockTransferSrcScalarPerVector;
ck::index_t ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<ck::index_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 3> BBlockTransferSrcAccessOrder;
ck::index_t BBlockTransferSrcVectorDim;
ck::index_t BBlockTransferSrcScalarPerVector;
ck::index_t BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 8> CThreadTransferSrcDstAccessOrder;
ck::index_t CThreadTransferSrcDstVectorDim;
ck::index_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
{
ck::index_t BlockSize;
ck::index_t GM1PerBlockGM11;
ck::index_t GN1PerBlockGN11;
ck::index_t KPerBlock;
ck::index_t M1PerThread;
ck::index_t N1PerThread;
ck::index_t KPerThread;
ck::index_t M1N1ThreadClusterM10;
ck::index_t M1N1ThreadClusterN10;
ck::index_t M1N1ThreadClusterM11;
ck::index_t M1N1ThreadClusterN11;
std::array<ck::index_t, 4> ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11;
std::array<ck::index_t, 4> ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11;
std::array<ck::index_t, 4> ABlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 4> ABlockTransferSrcAccessOrder;
ck::index_t ABlockTransferSrcVectorDim;
ck::index_t ABlockTransferSrcScalarPerVector;
ck::index_t ABlockTransferDstScalarPerVector_GM11;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 4> BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11;
std::array<ck::index_t, 4> BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11;
std::array<ck::index_t, 4> BBlockTransferThreadClusterArrangeOrder;
std::array<ck::index_t, 4> BBlockTransferSrcAccessOrder;
ck::index_t BBlockTransferSrcVectorDim;
ck::index_t BBlockTransferSrcScalarPerVector;
ck::index_t BBlockTransferDstScalarPerVector_GN11;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<ck::index_t, 6> CThreadTransferSrcDstAccessOrder;
ck::index_t CThreadTransferSrcDstVectorDim;
ck::index_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw = {
256,
128,
32,
8,
4,
4,
1,
2,
2,
8,
8,
{4, 1, 1, 1},
{2, 1, 1, 128},
{3, 2, 1, 0},
{3, 2, 1, 0},
0,
4,
1,
false,
{1, 4, 1, 1},
{8, 1, 1, 32},
{0, 3, 2, 1},
{0, 3, 2, 1},
3,
1,
1,
false,
{3, 4, 5, 0, 1, 2},
5,
1};
static inline int
conv_hw_out_size(int hw_in_size, int leftPad, int rightPad, int dilation, int yx_size, int stride)
{
return (hw_in_size + leftPad + rightPad - dilation * (yx_size - 1) - 1) / stride + 1;
}
#endif

View File

@@ -1,520 +0,0 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1r2.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 0
// cdata = 16, BlockSize = 64, 16x64x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlockM1 = 16;
constexpr index_t GemmNPerBlockN1 = 64;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmM1PerThreadM111 = 2;
constexpr index_t GemmN1PerThreadN111 = 2;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
#elif 0
// cdata = 32, BlockSize = 64, 16x128x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlockM1 = 16;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmM1PerThreadM111 = 2;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x2
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlockM1 = 16;
constexpr index_t GemmNPerBlockN1 = 256;
constexpr index_t GemmKPerBlock = 2;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 1;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<2, 1, 4>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 2;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 0
// cdata = 64, BlockSize = 64, 16x256x4
constexpr index_t BlockSize = 64;
constexpr index_t GemmMPerBlockM1 = 16;
constexpr index_t GemmNPerBlockN1 = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 1;
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 4>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x4
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlockM1 = 32;
constexpr index_t GemmNPerBlockN1 = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 0
// cdata = 64, BlockSize = 128, 32x256x8
constexpr index_t BlockSize = 128;
constexpr index_t GemmMPerBlockM1 = 32;
constexpr index_t GemmNPerBlockN1 = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 2;
constexpr index_t GemmM11N11ThreadClusterN1100 = 16;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<2, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 2;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x8
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlockM1 = 128;
constexpr index_t GemmNPerBlockN1 = 128;
constexpr index_t GemmKPerBlock = 16;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 2>;
using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 64>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 2;
using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4;
#endif
#if 1
const auto descs =
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
#if 0
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}));
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
#else
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{};
#endif
#else
const auto descs =
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{}));
constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
#endif
const auto wei_gemmk_gemmm_grid_desc = descs[I0];
const auto in_gemmk_gemmn_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_v1r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gemmk_gemmm_grid_desc),
decltype(in_gemmk_gemmn_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlockM1,
GemmNPerBlockN1,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_K_M0_M1,
GemmABlockTransferThreadClusterLengths_K_M0_M1,
Sequence<1, 2, 0>, // ABlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0>, // ABlockTransferSrcAccessOrder
0, // ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_K,
GemmABlockTransferDstScalarPerVector_M1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_K_N0_N1,
GemmBBlockTransferThreadClusterLengths_K_N0_N1,
Sequence<1, 2, 0>, // BBlockTransferThreadClusterArrangeOrder
Sequence<1, 2, 0>, // BBlockTransferSrcAccessOrder
0, // BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_K,
GemmBBlockTransferDstScalarPerVector_N1,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
2, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_M11,
decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks),
decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk_gemmm_grid_desc,
in_gemmk_gemmn_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks,
in_gemmk_gemmn0_gemmn1_grid_iterator_hacks,
out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks,
in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -1,240 +0,0 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_v1r1.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
#if 1
// cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32]
constexpr index_t BlockSize = 256;
constexpr index_t N0 = 4;
constexpr index_t GemmGM1PerBlockGM11 = 128;
constexpr index_t GemmGN1PerBlockGN11 = 32;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;
using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 1, 1, 32>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
// cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 8, 16] = [1, 128, 8, 16]
constexpr index_t BlockSize = 256;
constexpr index_t N0 = 8;
constexpr index_t GemmGM1PerBlockGM11 = 128;
constexpr index_t GemmGN1PerBlockGN11 = 16;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1;
using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 2, 1, 16>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1;
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#endif
const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
const auto wei_gk_gm0_gm1_grid_desc = descs[I0];
const auto in_gk_gn0_gn1_grid_desc = descs[I1];
const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gk_gm0_gm10_gm11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0>{}));
constexpr auto in_gk_gn0_gn10_gn11_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}));
constexpr auto wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0>{};
constexpr auto in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_contraction_v1r1<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gk_gm0_gm1_grid_desc),
decltype(in_gk_gn0_gn1_grid_desc),
decltype(out_gm0_gm1_gn0_gn1_grid_desc),
GemmGM1PerBlockGM11,
GemmGN1PerBlockGN11,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11,
GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11,
Sequence<3, 2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder
Sequence<3, 2, 1, 0>, // ABlockTransferSrcAccessOrder
0, // ABlockTransferSrcVectorDim
GemmABlockTransferSrcScalarPerVector_GK,
GemmABlockTransferDstScalarPerVector_GM11,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11,
GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11,
Sequence<0, 3, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
Sequence<0, 3, 2, 1>, // BBlockTransferSrcAccessOrder
3, // BBlockTransferSrcVectorDim
GemmBBlockTransferSrcScalarPerVector_GN11,
GemmBBlockTransferDstScalarPerVector_GN11,
false, // don't move back src coordinate after threadwise copy
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_BN1,
decltype(wei_gk_gm0_gm10_gm11_grid_iterator_hacks),
decltype(in_gk_gn0_gn10_gn11_grid_iterator_hacks),
decltype(out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks),
decltype(wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks),
decltype(in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gk_gm0_gm1_grid_desc,
in_gk_gn0_gn1_grid_desc,
out_gm0_gm1_gn0_gn1_grid_desc,
wei_gk_gm0_gm10_gm11_grid_iterator_hacks,
in_gk_gn0_gn10_gn11_grid_iterator_hacks,
out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks,
wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks,
in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks,
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -1,404 +0,0 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "olc_driver_common.hpp"
#include "conv_tunables.hpp"
#include "handle.hpp"
namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw {
template <typename TInWei, typename TAcc, typename TOut>
static std::string get_network_config_string_from_types()
{
std::string out;
out += static_cast<char>(Driver::get_typeid_from_type<TInWei>()) +
static_cast<char>(Driver::get_typeid_from_type<TAcc>()) +
static_cast<char>(Driver::get_typeid_from_type<TOut>());
return (out);
};
static std::string
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* pt)
{
std::string out("TUN_");
out += std::to_string(pt->BlockSize) + "_";
out += std::to_string(pt->GM1PerBlockGM11) + "x" + std::to_string(pt->GN1PerBlockGN11) + "x" +
std::to_string(pt->KPerBlock) + "_";
out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" +
std::to_string(pt->KPerThread) + "_";
out += std::to_string(pt->M1N1ThreadClusterM10) + "x" +
std::to_string(pt->M1N1ThreadClusterN10) + "x" +
std::to_string(pt->M1N1ThreadClusterM11) + "x" +
std::to_string(pt->M1N1ThreadClusterN11) + "_";
out += std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[0]) + "x" +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[1]) + "x" +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[2]) + "x" +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[3]) + "_";
out += std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[0]) + "x" +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[1]) + "x" +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[2]) + "x" +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[3]) + "_";
out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "x" +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[3]) + "_";
out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" +
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" +
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "x" +
std::to_string(pt->ABlockTransferSrcAccessOrder[3]) + "_";
out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_";
out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_";
out += std::to_string(pt->ABlockTransferDstScalarPerVector_GM11) + "_";
out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_";
out += std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[0]) + "x" +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[1]) + "x" +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[2]) + "x" +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[3]);
out += std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[0]) + "x" +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[1]) + "x" +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[2]) + "x" +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[3]) + "_";
out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "x" +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[3]) + "_";
out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" +
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" +
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "x" +
std::to_string(pt->BBlockTransferSrcAccessOrder[3]) + "_";
out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_";
out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_";
out += std::to_string(pt->BBlockTransferDstScalarPerVector_GN11) + "_";
out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_";
out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_";
out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_";
out += std::to_string(pt->CThreadTransferDstScalarPerVector);
return (out);
};
template <typename TInWei, typename TAcc, typename TOut>
static std::string get_definition_string_from_types()
{
std::string out;
out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TInWei>()) +
" -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
" -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
return (out);
};
static std::string
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* pt)
{
std::string out;
out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize);
out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(pt->GM1PerBlockGM11) +
" -DCK_PARAM_GN1PerBlockGN11=" + std::to_string(pt->GN1PerBlockGN11) +
" -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock);
out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) +
" -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) +
" -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread);
out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) +
" -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) +
" -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) +
" -DCK_PARAM_M1N1ThreadClusterN11=" + std::to_string(pt->M1N1ThreadClusterN11);
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11=" +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[0]) + "," +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[1]) + "," +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[2]) + "," +
std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[3]);
out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11=" +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[0]) + "," +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[1]) + "," +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[2]) + "," +
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[3]);
out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "," +
std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[3]);
out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" +
std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," +
std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," +
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "," +
std::to_string(pt->ABlockTransferSrcAccessOrder[3]);
out +=
" -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim);
out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" +
std::to_string(pt->ABlockTransferSrcScalarPerVector);
out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_GM11=" +
std::to_string(pt->ABlockTransferDstScalarPerVector_GM11);
out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" +
std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun);
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11=" +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[0]) + "," +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[1]) + "," +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[2]) + "," +
std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[3]);
out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11=" +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[0]) + "," +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[1]) + "," +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[2]) + "," +
std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[3]);
out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "," +
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[3]);
out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" +
std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," +
std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," +
std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "," +
std::to_string(pt->BBlockTransferSrcAccessOrder[3]);
out +=
" -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim);
out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" +
std::to_string(pt->BBlockTransferSrcScalarPerVector);
out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_GN11=" +
std::to_string(pt->BBlockTransferDstScalarPerVector_GN11);
out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" +
std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun);
out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," +
std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]);
out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" +
std::to_string(pt->CThreadTransferSrcDstVectorDim);
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
std::to_string(pt->CThreadTransferDstScalarPerVector);
return (out);
};
} // namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_olc(
olCompile::Handle* handle,
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* tunable,
ck::index_t nrepeat)
{
using namespace ck;
using namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw;
using size_t = std::size_t;
constexpr index_t N0 = 4; // this could not be a tunable so far
////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The follow codes are only used for computing the grid_size, hasMainKBlockLoop,
// hasDoubleTailKBlockLoop
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad<N0>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
const auto a_gk_gm0_gm1_grid_desc = descs[I0];
const auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);
const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0);
const index_t grid_size = (GM1 / tunable->GM1PerBlockGM11) * (GN1 / tunable->GN1PerBlockGN11);
const bool hasMainKBlockLoop = ((GK + tunable->KPerBlock) / (2 * tunable->KPerBlock) > 1);
const bool hasDoubleTailKBlockLoop = ((GK / tunable->KPerBlock) % 2 == 0);
///////////////////////////////////////////////////////////////////////////////////////////////////////////
// these buffers are usually provided by the user application
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
// these are workspace buffers that should be expressed to the user by the corresponding
// workspace API
DeviceMem workspace_buf(4096);
void* a_gk_gm0_gm10_gm11_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer();
void* b_gk_gn0_gn10_gn11_grid_desc_dev_buf =
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
void* c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf =
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
void* c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf =
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
const std::vector<size_t> vld = {static_cast<size_t>(tunable->BlockSize), 1, 1};
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable->BlockSize), 1, 1};
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable->BlockSize), 1, 1};
std::string program_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp";
std::string algo_name = "implicit_gemm_conv_fwd_v4r4_nchw";
std::string param = " -std=c++17 ";
std::string network_config;
param += get_definition_string_from_types<TInWei, TAcc, TOut>() +
" -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) +
" -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) +
" -DCK_PARAM_N0=" + std::to_string(N0) + " " +
get_definition_string_from_tunable(tunable);
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_V" +
std::to_string(hasDoubleTailKBlockLoop) + "_" + std::to_string(N0) + "_" +
get_network_config_string_from_tunable(tunable);
std::vector<float> kernel1_times;
std::vector<float> kernel2_times;
for(index_t i = 0; i < nrepeat; ++i)
{
KernelTimer timer1, timer2;
std::string kernel_name;
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_prepare";
auto network_config_1 = network_config + "_1";
timer1.Start();
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
conv_strides[I0],
conv_strides[I1],
conv_dilations[I0],
conv_dilations[I1],
in_left_pads[I0],
in_left_pads[I1],
in_right_pads[I0],
in_right_pads[I1],
a_gk_gm0_gm10_gm11_grid_desc_dev_buf,
b_gk_gn0_gn10_gn11_grid_desc_dev_buf,
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf,
c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf);
timer2.End();
kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw";
auto network_config_2 = network_config + "_2";
timer2.Start();
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
(const void*)(a_gk_gm0_gm10_gm11_grid_desc_dev_buf),
(const void*)(b_gk_gn0_gn10_gn11_grid_desc_dev_buf),
(const void*)(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf),
(const void*)(c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf));
timer2.End();
kernel1_times.push_back(timer1.GetElapsedTime());
kernel2_times.push_back(timer2.GetElapsedTime());
}
{
auto ave_time1 = Driver::get_effective_average(kernel1_times);
auto ave_time2 = Driver::get_effective_average(kernel2_times);
const auto N = in_n_c_hi_wi_lengths[I0];
const auto C = in_n_c_hi_wi_lengths[I1];
const auto K = out_n_k_ho_wo_lengths[I1];
const auto Ho = out_n_k_ho_wo_lengths[I2];
const auto Wo = out_n_k_ho_wo_lengths[I3];
const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
};
// copy result back to host
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

4
host/CMakeLists.txt Normal file
View File

@@ -0,0 +1,4 @@
add_subdirectory(host_tensor)
add_subdirectory(online_compilation)
add_subdirectory(driver_offline)
add_subdirectory(driver_online)

View File

@@ -0,0 +1,21 @@
include_directories(BEFORE
include
${PROJECT_SOURCE_DIR}/host/host_tensor/include
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include
${PROJECT_SOURCE_DIR}/external/half/include
)
set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp)
set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp)
add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE})
add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE})
target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor)
target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor)

View File

@@ -13,34 +13,28 @@
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R5R2_NCHW 0
#define USE_CONV_FWD_V6R1_NCHW 0
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1
#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0
enum ConvForwardAlgo
{
V4R4NCHW, // 0
V4R4NHWC, // 1
V4R4R2NHWC, // 2
V4R5NCHW, // 3
V4R5R2NCHW, // 4
V5R1NCHW, // 5
V4R4R2XDLNCHW, // 6
V4R4R4XDLNHWC // 7
V4R4R2NHWC, // 1
V6R1NCHW, // 2
V5R1NCHW, // 3
V4R4R2XDLNCHW, // 4
V4R4R4XDLNHWC // 5
};
int main(int argc, char* argv[])
@@ -132,7 +126,7 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif
#if 0
#if 1
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
@@ -323,32 +317,6 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_FWD_V4R4_NHWC
if(algo == ConvForwardAlgo::V4R4NHWC)
{
if(layout != ConvTensorLayout::NHWC)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nhwc();
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V4R4R2_NHWC
if(algo == ConvForwardAlgo::V4R4R2NHWC)
{
@@ -376,8 +344,8 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_FWD_V4R5_NCHW
if(algo == ConvForwardAlgo::V4R5NCHW)
#if USE_CONV_FWD_V6R1_NCHW
if(algo == ConvForwardAlgo::V6R1NCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
@@ -386,7 +354,7 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw<in_data_t,
device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(tmp[I0],
tmp[I1],
@@ -402,33 +370,6 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_FWD_V4R5R2_NCHW
if(algo == ConvForwardAlgo::V4R5R2NCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
throw std::runtime_error("wrong! layout");
}
const auto tmp = f_make_for_device_nchw();
device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(
tmp[I0],
tmp[I1],
tmp[I2],
tmp[I3],
tmp[I4],
tmp[I5],
tmp[I6],
in,
wei,
out_device,
nrepeat);
}
#endif
#if USE_CONV_FWD_V5R1_NCHW
if(algo == ConvForwardAlgo::V5R1NCHW)
{

View File

@@ -0,0 +1,283 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
#if 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 0
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4]
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmKPack = 4;
constexpr index_t MRepeat = 1;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1;
#endif
const auto descs =
#if 1
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
#else
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1
#endif
<TInWei, GemmMPerBlock, GemmNPerBlock, GemmMPerWave, GemmNPerWave, GemmKPack>(
wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
for(index_t i = 0; i < 5; ++i)
{
#if 0
float ave_time = launch_kernel_dynamic_gemm_xdlops_v1
#else
float ave_time = launch_kernel_dynamic_gemm_xdlops_v2
#endif
<BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(descs[I0]),
decltype(descs[I1]),
decltype(descs[I2]),
decltype(descs[I3]),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
GemmKPack,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK,
GemmABlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<0, 2, 1>,
Sequence<1, 0, 2>,
1,
GemmBBlockTransferSrcScalarPerVector_GemmN,
GemmBBlockTransferDstScalarPerVector_KPack,
false, // don't move back src coordinate after threadwise copy, which will be fused
// with MoveSrcSliceWindow() to save addr computation
Sequence<2, 3, 0, 1>,
3,
GemmCThreadTransferDstScalarPerVector_GemmN1,
decltype(descs[I4]),
decltype(descs[I5]),
decltype(descs[I6]),
decltype(descs[I7]),
decltype(descs[I8])>(static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
descs[I0],
descs[I1],
descs[I2],
descs[I3],
descs[I4],
descs[I5],
descs[I6],
descs[I7],
descs[I8],
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
}
// copy result back to host
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,240 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 1;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r2<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1>,
2,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
out_m0_m1_m2_n_grid_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -0,0 +1,305 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,
const OutLengths& out_n_ho_wo_k_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_hi_wi_c,
const Tensor<TInWei>& wei_k_y_x_c,
Tensor<TOut>& out_n_ho_wo_k,
ck::index_t nrepeat)
{
using namespace ck;
std::cout << __func__ << std::endl;
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I4 = Number<4>{};
constexpr auto I5 = Number<5>{};
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
constexpr auto I8 = Number<8>{};
DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
const auto in_n_hi_wi_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
const auto wei_k_y_x_c_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
const auto out_n_ho_wo_k_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
#if 1
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 4;
constexpr index_t MRepeat = 2;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 0
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 4;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#elif 1
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 256;
constexpr index_t GemmNPerBlock = 128;
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmMPerWave = 32;
constexpr index_t GemmNPerWave = 32;
constexpr index_t GemmK1 = 8;
constexpr index_t MRepeat = 4;
constexpr index_t NRepeat = 2;
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
#endif
const auto descs =
transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
in_n_hi_wi_c_desc,
out_n_ho_wo_k_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GemmK1>{});
const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
const auto out_gemmm_gemmn_grid_desc = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
make_tuple(
Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 1, 0, 0>{}),
make_tuple(Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 0, 0, 0>{},
Sequence<0, 0, 2, 0, 0>{}));
constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0>{};
constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
for(index_t i = 0; i < 5; ++i)
{
float ave_time = driver_dynamic_gemm_xdlops_v2r3<
BlockSize,
TInWei,
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
decltype(out_gemmm_gemmn_grid_desc),
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerWave,
GemmNPerWave,
MRepeat,
NRepeat,
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmABlockTransferSrcScalarPerVector_GemmK1,
GemmABlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
Sequence<1, 0, 2>,
Sequence<1, 0, 2>,
2,
GemmBBlockTransferSrcScalarPerVector_GemmK1,
GemmBBlockTransferDstScalarPerVector_GemmK1,
false, // don't move back src coordinate after threadwise copy
Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
6,
GemmCThreadTransferDstScalarPerVector,
decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
decltype(out_m0_m1_m2_n_grid_iterator_hacks),
decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks),
false // CAccessOrderMRepeatNRepeat
>(static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
wei_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmk0_gemmn_gemmk1_grid_desc,
out_gemmm_gemmn_grid_desc,
wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
out_m0_m1_m2_n_grid_iterator_hacks,
wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
nrepeat);
{
const auto N = out_n_ho_wo_k_lengths[I0];
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
const auto Y = wei_k_y_x_c_lengths[I1];
const auto X = wei_k_y_x_c_lengths[I2];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
<< std::endl;
}
}
// copy result back to host
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

View File

@@ -1,7 +1,7 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_v1r2.hpp"
template <typename TInWei,
@@ -14,7 +14,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
@@ -43,11 +43,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
const auto in_n_c_hi_wi_desc =
const auto in_desc_n_c_hi_wi =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
const auto wei_desc_k_c_y_x =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
const auto out_desc_n_k_ho_wo =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
#if 1
@@ -58,32 +58,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
constexpr index_t GN0 = 4;
constexpr index_t GK1 = 1;
constexpr index_t GemmGM1PerBlockGM11 = 128;
constexpr index_t GemmGN1PerBlockGN11 = 32;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GM1PerBlockGM11 = 128;
constexpr index_t GN1PerBlockGN11 = 32;
constexpr index_t GK0PerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t BM1PerThreadBM11 = 4;
constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t BM10BN10ThreadClusterBM100 = 8;
constexpr index_t BM10BN10ThreadClusterBN100 = 8;
constexpr index_t BM10BN10ThreadClusterBM101 = 2;
constexpr index_t BM10BN10ThreadClusterBN101 = 2;
using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;
using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
// [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
// cdata = 64, BlockSize = 256
@@ -92,48 +92,48 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
constexpr index_t GN0 = 4;
constexpr index_t GK1 = 2;
constexpr index_t GemmGM1PerBlockGM11 = 128;
constexpr index_t GemmGN1PerBlockGN11 = 32;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GM1PerBlockGM11 = 128;
constexpr index_t GN1PerBlockGN11 = 32;
constexpr index_t GK0PerBlock = 8;
constexpr index_t GemmM1PerThreadM111 = 4;
constexpr index_t GemmN1PerThreadN111 = 4;
constexpr index_t GemmKPerThread = 1;
constexpr index_t BM1PerThreadBM11 = 4;
constexpr index_t BN1PerThreadBN11 = 4;
constexpr index_t BK0PerThread = 1;
constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
constexpr index_t BM10BN10ThreadClusterBM100 = 8;
constexpr index_t BM10BN10ThreadClusterBN100 = 8;
constexpr index_t BM10BN10ThreadClusterBM101 = 2;
constexpr index_t BM10BN10ThreadClusterBN101 = 2;
using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;
using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;
using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;
using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;
constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1;
#endif
const auto descs =
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GN0>{},
Number<GK1>{});
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x,
in_desc_n_c_hi_wi,
out_desc_n_k_ho_wo,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
Number<GN0>{},
Number<GK1>{});
const auto wei_gk0_gm0_gm1_gk1_grid_desc = descs[I0];
const auto in_gk0_gn0_gn1_gk1_grid_desc = descs[I1];
const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];
const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1];
const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
// HACK: hacks that control index calculation when iterating over A, B, C matrix
constexpr auto wei_grid_iterator_hacks =
@@ -189,36 +189,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
TAcc,
TOut,
InMemoryDataOperation::Set,
decltype(wei_gk0_gm0_gm1_gk1_grid_desc),
decltype(in_gk0_gn0_gn1_gk1_grid_desc),
decltype(out_gm0_gm1_gn0_gn1_grid_desc),
GemmGM1PerBlockGM11,
GemmGN1PerBlockGN11,
GemmKPerBlock,
GemmM1PerThreadM111,
GemmN1PerThreadN111,
GemmKPerThread,
GemmM11N11ThreadClusterM1100,
GemmM11N11ThreadClusterN1100,
GemmM11N11ThreadClusterM1101,
GemmM11N11ThreadClusterN1101,
GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
decltype(wei_grid_desc_gk0_gm0_gm1_gk1),
decltype(in_grid_desc_gk0_gn0_gn1_gk1),
decltype(out_grid_desc_gm0_gm1_gn0_gn1),
GM1PerBlockGM11,
GN1PerBlockGN11,
GK0PerBlock,
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
5, // CThreadTransferSrcDstVectorDim
GemmCThreadTransferDstScalarPerVector_BN1,
CThreadTransferDstScalarPerVector_BN1,
decltype(wei_grid_iterator_hacks),
decltype(in_grid_iterator_hacks),
decltype(out_grid_iterator_hacks),
@@ -227,9 +227,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
wei_gk0_gm0_gm1_gk1_grid_desc,
in_gk0_gn0_gn1_gk1_grid_desc,
out_gm0_gm1_gn0_gn1_grid_desc,
wei_grid_desc_gk0_gm0_gm1_gk1,
in_grid_desc_gk0_gn0_gn1_gk1,
out_grid_desc_gm0_gm1_gn0_gn1,
wei_grid_iterator_hacks,
in_grid_iterator_hacks,
out_grid_iterator_hacks,
@@ -238,7 +238,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
nrepeat);
float perf = (float)calculate_convolution_flops(
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) /
(std::size_t(1000) * 1000 * 1000) / ave_time;
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;

View File

@@ -0,0 +1,21 @@
include_directories(BEFORE
include
${PROJECT_BINARY_DIR}/host/online_compilation/include
${PROJECT_SOURCE_DIR}/host/online_compilation/include
${PROJECT_SOURCE_DIR}/host/host_tensor/include
${PROJECT_SOURCE_DIR}/composable_kernel/include
${PROJECT_SOURCE_DIR}/composable_kernel/include/utility
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform
${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
${PROJECT_SOURCE_DIR}/external/rocm/include
${PROJECT_SOURCE_DIR}/external/half/include
)
set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp)
add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE})
target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor)
target_link_libraries(conv_fwd_driver_online PRIVATE online_compilation)

View File

@@ -12,26 +12,22 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V4R5_NCHW 1
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
#include "conv_tunables.hpp"
#include "handle.hpp"
#include "hipCheck.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_CONV_FWD_V4R4_NCHW 1
#define USE_CONV_FWD_V6R1_NCHW 1
#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1
#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1
enum ConvForwardAlgo
{
V4R4NCHW, // 0
V4R5NCHW, // 1
V6R1NCHW, // 1
V4R4XDLNCHW, // 2
V4R4XDLNHWC // 3
};
@@ -94,15 +90,17 @@ int main(int argc, char* argv[])
const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#if 1
constexpr index_t in_vector_size = 1;
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
constexpr index_t in_vector_size = 16;
using in_data_t = int8_t;
using acc_data_t = int32_t;
using out_data_t = int8_t;
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 1
using in_data_t = int8_t;
using acc_data_t = int32_t;
using out_data_t = int8_t;
#endif
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
@@ -230,9 +228,9 @@ int main(int argc, char* argv[])
tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* tunable =
&default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_olc<in_data_t,
acc_data_t,
out_data_t>(
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(
handle,
tmp[I0],
tmp[I1],
@@ -249,8 +247,8 @@ int main(int argc, char* argv[])
}
#endif
#if USE_CONV_FWD_V4R5_NCHW
if(algo == ConvForwardAlgo::V4R5NCHW)
#if USE_CONV_FWD_V6R1_NCHW
if(algo == ConvForwardAlgo::V6R1NCHW)
{
if(layout != ConvTensorLayout::NCHW)
{
@@ -259,12 +257,11 @@ int main(int argc, char* argv[])
const auto tmp = f_make_for_device_nchw();
tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* tunable =
&default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw;
const auto tunable = tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw{};
device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_olc<in_data_t,
acc_data_t,
out_data_t>(
online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw<in_data_t,
acc_data_t,
out_data_t>(
handle,
tmp[I0],
tmp[I1],
@@ -294,22 +291,22 @@ int main(int argc, char* argv[])
tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable =
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc<in_data_t,
acc_data_t,
out_data_t>(
handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
tunable,
nrepeat);
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw<
in_data_t,
acc_data_t,
out_data_t>(handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
tunable,
nrepeat);
}
#endif
@@ -326,22 +323,22 @@ int main(int argc, char* argv[])
tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable =
&default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk;
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc<in_data_t,
acc_data_t,
out_data_t>(
handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
tunable,
nrepeat);
online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk<
in_data_t,
acc_data_t,
out_data_t>(handle,
tmp[I0],
tmp[I1],
tmp[I2],
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
in,
wei,
out_device,
tunable,
nrepeat);
}
#endif

View File

@@ -0,0 +1,50 @@
#ifndef CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP
#define CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP
struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw
{
int32_t BlockSize;
int32_t MPerBlock;
int32_t NPerBlock;
int32_t KPerBlock;
int32_t M1PerThread;
int32_t N1PerThread;
int32_t KPerThread;
int32_t M1N1ThreadClusterM10;
int32_t M1N1ThreadClusterN10;
int32_t M1N1ThreadClusterM11;
int32_t M1N1ThreadClusterN11;
std::array<int32_t, 3> ABlockTransferThreadSliceLengths_K_M0_M1;
std::array<int32_t, 3> ABlockTransferThreadClusterLengths_K_M0_M1;
std::array<int32_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<int32_t, 3> ABlockTransferSrcAccessOrder;
int32_t ABlockTransferSrcVectorDim;
int32_t ABlockTransferSrcScalarPerVector;
int32_t ABlockTransferDstScalarPerVector_M1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<int32_t, 3> BBlockTransferThreadSliceLengths_K_N0_N1;
std::array<int32_t, 3> BBlockTransferThreadClusterLengths_K_N0_N1;
std::array<int32_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<int32_t, 3> BBlockTransferSrcAccessOrder;
int32_t BBlockTransferSrcVectorDim;
int32_t BBlockTransferSrcScalarPerVector;
int32_t BBlockTransferDstScalarPerVector_N1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<int32_t, 6> CThreadTransferSrcDstAccessOrder;
int32_t CThreadTransferSrcDstVectorDim;
int32_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = {
256, 128, 128, 8, 4, 4, 1,
8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0},
{2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128},
{0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2},
5, 1};
#endif

View File

@@ -0,0 +1,73 @@
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
{
int32_t BlockSize;
int32_t MPerBlock;
int32_t NPerBlock;
int32_t KPerBlock;
int32_t MPerWave;
int32_t NPerWave;
int32_t K1;
int32_t MRepeat;
int32_t NRepeat;
std::array<int32_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<int32_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<int32_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<int32_t, 3> ABlockTransferSrcAccessOrder;
int32_t ABlockTransferSrcVectorDim;
int32_t ABlockTransferSrcScalarPerVector;
int32_t ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<int32_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<int32_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<int32_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<int32_t, 3> BBlockTransferSrcAccessOrder;
int32_t BBlockTransferSrcVectorDim;
int32_t BBlockTransferSrcScalarPerVector;
int32_t BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<int32_t, 8> CThreadTransferSrcDstAccessOrder;
int32_t CThreadTransferSrcDstVectorDim;
int32_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw
default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
1, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
#endif

View File

@@ -0,0 +1,73 @@
#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP
struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
{
int32_t BlockSize;
int32_t MPerBlock;
int32_t NPerBlock;
int32_t KPerBlock;
int32_t MPerWave;
int32_t NPerWave;
int32_t K1;
int32_t MRepeat;
int32_t NRepeat;
std::array<int32_t, 3> ABlockTransferThreadSliceLengths_K0_M_K1;
std::array<int32_t, 3> ABlockTransferThreadClusterLengths_K0_M_K1;
std::array<int32_t, 3> ABlockTransferThreadClusterArrangeOrder;
std::array<int32_t, 3> ABlockTransferSrcAccessOrder;
int32_t ABlockTransferSrcVectorDim;
int32_t ABlockTransferSrcScalarPerVector;
int32_t ABlockTransferDstScalarPerVector_K1;
bool AThreadTransferSrcResetCoordinateAfterRun;
std::array<int32_t, 3> BBlockTransferThreadSliceLengths_K0_N_K1;
std::array<int32_t, 3> BBlockTransferThreadClusterLengths_K0_N_K1;
std::array<int32_t, 3> BBlockTransferThreadClusterArrangeOrder;
std::array<int32_t, 3> BBlockTransferSrcAccessOrder;
int32_t BBlockTransferSrcVectorDim;
int32_t BBlockTransferSrcScalarPerVector;
int32_t BBlockTransferDstScalarPerVector_K1;
bool BThreadTransferSrcResetCoordinateAfterRun;
std::array<int32_t, 8> CThreadTransferSrcDstAccessOrder;
int32_t CThreadTransferSrcDstVectorDim;
int32_t CThreadTransferDstScalarPerVector;
};
static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk
default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = {
256, // BlockSize
128, // MPerBlock,
128, // NPerBlock,
4, // KPerBlock,
32, // MPerWave,
32, // NPerWave,
4, // K1,
2, // MRepeat,
2, // NRepeat,
{1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1,
{4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1,
{1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // ABlockTransferSrcAccessOrder,
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector,
4, // ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
{1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1,
{4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1,
{1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder,
{1, 0, 2}, // BBlockTransferSrcAccessOrder,
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_K1
false, // BThreadTransferSrcResetCoordinateAfterRun
{2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder
7, // CThreadTransferSrcDstVectorDim,
1 // CThreadTransferDstScalarPerVector
};
#endif

View File

@@ -0,0 +1,42 @@
#ifndef CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP
#define CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP
struct tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw
{
int32_t BlockSize = 256;
int32_t GN0 = 4;
int32_t GK1 = 1;
int32_t GM1PerBlockGM11 = 128;
int32_t GN1PerBlockGN11 = 32;
int32_t GK0PerBlock = 8;
int32_t BM1PerThreadBM11 = 4;
int32_t BN1PerThreadBN11 = 4;
int32_t BK0PerThread = 1;
int32_t BM10BN10ThreadClusterBM100 = 2;
int32_t BM10BN10ThreadClusterBN100 = 2;
int32_t BM10BN10ThreadClusterBM101 = 8;
int32_t BM10BN10ThreadClusterBN101 = 8;
std::array<int32_t, 5> ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {4, 1, 1, 1, 1};
std::array<int32_t, 5> ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = {
2, 1, 1, 128, 1};
std::array<int32_t, 5> ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
4, 1, 1, 1, 1};
std::array<int32_t, 5> ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = {
1, 1, 1, 1, 1};
std::array<int32_t, 5> BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {1, 4, 1, 1, 1};
std::array<int32_t, 5> BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = {
8, 1, 1, 32, 1};
std::array<int32_t, 5> BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
1, 1, 1, 1, 1};
std::array<int32_t, 5> BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = {
1, 1, 1, 1, 1};
int32_t CThreadTransferDstScalarPerVector = 1;
};
#endif

View File

@@ -1,13 +1,11 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "olc_driver_common.hpp"
#include "conv_tunables.hpp"
#include "handle.hpp"
#include "conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp"
namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw {
@@ -211,7 +209,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_olc(
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
olCompile::Handle* handle,
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,

View File

@@ -1,12 +1,10 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "olc_driver_common.hpp"
#include "conv_tunables.hpp"
#include "handle.hpp"
#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw {
@@ -208,7 +206,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc(
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
olCompile::Handle* handle,
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,

View File

@@ -1,13 +1,11 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "olc_driver_common.hpp"
#include "conv_tunables.hpp"
#include "handle.hpp"
#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp"
namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk {
@@ -209,7 +207,7 @@ template <typename TInWei,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc(
void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk(
olCompile::Handle* handle,
const InLengths& in_n_hi_wi_c_lengths,
const WeiLengths& wei_k_y_x_c_lengths,

View File

@@ -0,0 +1,425 @@
#include "device.hpp"
#include "host_tensor.hpp"
#include "handle.hpp"
#include "online_driver_common.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"
#include "conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp"
namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw {
template <typename TInWei, typename TAcc, typename TOut>
static std::string get_network_config_string_from_types()
{
std::string out("DAT_");
out += static_cast<char>(Driver::get_typeid_from_type<TInWei>()) +
static_cast<char>(Driver::get_typeid_from_type<TAcc>()) +
static_cast<char>(Driver::get_typeid_from_type<TOut>());
return (out);
};
static std::string
get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable)
{
std::string out("TUN_");
out += std::to_string(tunable.BlockSize) + "_";
out += std::to_string(tunable.GN0) + "x" + std::to_string(tunable.GK1) + "_";
out += std::to_string(tunable.GM1PerBlockGM11) + "x" + std::to_string(tunable.GN1PerBlockGN11) +
"x" + std::to_string(tunable.GK0PerBlock) + "_";
out += std::to_string(tunable.BM1PerThreadBM11) + "x" +
std::to_string(tunable.BN1PerThreadBN11) + "x" + std::to_string(tunable.BK0PerThread) +
"_";
out += std::to_string(tunable.BM10BN10ThreadClusterBM100) + "x" +
std::to_string(tunable.BM10BN10ThreadClusterBN100) + "x" +
std::to_string(tunable.BM10BN10ThreadClusterBM101) + "x" +
std::to_string(tunable.BM10BN10ThreadClusterBN101) + "_";
out += std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_";
out +=
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_";
out += std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
"x" +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
"x" +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
"x" +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
"x" +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) +
"_";
out += std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
"x" +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
"x" +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
"x" +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
"x" +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) +
"_";
out += std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_";
out +=
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_";
out += std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
"x" +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
"x" +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
"x" +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
"x" +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) +
"_";
out += std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
"x" +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
"x" +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
"x" +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
"x" +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) +
"_";
out += std::to_string(tunable.CThreadTransferDstScalarPerVector);
return (out);
};
template <typename TInWei, typename TAcc, typename TOut>
static std::string get_definition_string_from_types()
{
std::string out;
out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TInWei>()) +
" -DCK_PARAM_ACC_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TAcc>()) +
" -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type<TOut>());
return (out);
};
static std::string
get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable)
{
std::string out;
out += " -DCK_PARAM_BlockSize=" + std::to_string(tunable.BlockSize);
out += " -DCK_PARAM_GN0=" + std::to_string(tunable.GN0);
out += " -DCK_PARAM_GK1=" + std::to_string(tunable.GK1);
out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(tunable.GM1PerBlockGM11) +
" -DCK_PARAM_GN1PerBlockGN11=" + std::to_string(tunable.GN1PerBlockGN11) +
" -DCK_PARAM_GK0PerBlock=" + std::to_string(tunable.GK0PerBlock);
out += " -DCK_PARAM_BM1PerThreadBM11=" + std::to_string(tunable.BM1PerThreadBM11) +
" -DCK_PARAM_BN1PerThreadBN11=" + std::to_string(tunable.BN1PerThreadBN11) +
" -DCK_PARAM_BK0PerThread=" + std::to_string(tunable.BK0PerThread);
out += " -DCK_PARAM_BM10BN10ThreadClusterBM100=" +
std::to_string(tunable.BM10BN10ThreadClusterBM100) +
" -DCK_PARAM_BM10BN10ThreadClusterBN100=" +
std::to_string(tunable.BM10BN10ThreadClusterBN100) +
" -DCK_PARAM_BM10BN10ThreadClusterBM101=" +
std::to_string(tunable.BM10BN10ThreadClusterBM101) +
" -DCK_PARAM_BM10BN10ThreadClusterBN101=" +
std::to_string(tunable.BM10BN10ThreadClusterBN101);
out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," +
std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]);
out +=
" -DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," +
std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]);
out += " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
"," +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
"," +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
"," +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
"," +
std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]);
out += " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) +
"," +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) +
"," +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) +
"," +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) +
"," +
std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]);
out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," +
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]);
out +=
" -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," +
std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]);
out += " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
"," +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
"," +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
"," +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
"," +
std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]);
out += " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) +
"," +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) +
"," +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) +
"," +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) +
"," +
std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]);
out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" +
std::to_string(tunable.CThreadTransferDstScalarPerVector);
return (out);
};
} // namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw
template <typename TInWei,
typename TAcc,
typename TOut,
typename InLengths,
typename WeiLengths,
typename OutLengths,
typename ConvStrides,
typename ConvDilations,
typename InLeftPads,
typename InRightPads>
void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
olCompile::Handle* handle,
const InLengths& in_n_c_hi_wi_lengths,
const WeiLengths& wei_k_c_y_x_lengths,
const OutLengths& out_n_k_ho_wo_lengths,
const ConvStrides& conv_strides,
const ConvDilations& conv_dilations,
const InLeftPads& in_left_pads,
const InRightPads& in_right_pads,
const Tensor<TInWei>& in_n_c_hi_wi,
const Tensor<TInWei>& wei_k_c_y_x,
Tensor<TOut>& out_n_k_ho_wo,
const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable,
ck::index_t nrepeat)
{
using namespace ck;
using namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw;
using size_t = std::size_t;
////////////////////////////////////////////////////////////////////////////////////////////////////////////
// The follow codes are only used for computing the grid_size, hasMainKBlockLoop,
// hasDoubleTailKBlockLoop
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
const auto in_n_c_hi_wi_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
const auto wei_k_c_y_x_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
const auto out_n_k_ho_wo_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
const auto descs =
transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
in_n_c_hi_wi_desc,
out_n_k_ho_wo_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads,
tunable.GN0,
tunable.GK1);
const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0];
const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2];
const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1);
const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3);
const auto GK = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);
const index_t grid_size = (GM1 / tunable.GM1PerBlockGM11) * (GN1 / tunable.GN1PerBlockGN11);
const bool hasMainKBlockLoop = ((GK + tunable.GK0PerBlock) / (2 * tunable.GK0PerBlock) > 1);
const bool hasDoubleTailKBlockLoop = ((GK / tunable.GK0PerBlock) % 2 == 0);
///////////////////////////////////////////////////////////////////////////////////////////////////////////
// these buffers are usually provided by the user application
DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data());
wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data());
out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data());
// these are workspace buffers that should be expressed to the user by the corresponding
// workspace API
DeviceMem workspace_buf(4096);
void* a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf = workspace_buf.GetDeviceBuffer();
void* b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf =
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 1024);
void* c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf =
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 2048);
void* c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf =
static_cast<void*>(static_cast<unsigned char*>(workspace_buf.GetDeviceBuffer()) + 3072);
const std::vector<size_t> vld = {static_cast<size_t>(tunable.BlockSize), 1, 1};
const std::vector<size_t> vgd1 = {static_cast<size_t>(tunable.BlockSize), 1, 1};
const std::vector<size_t> vgd2 = {static_cast<size_t>(grid_size * tunable.BlockSize), 1, 1};
std::string program_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp";
std::string algo_name = "implicit_gemm_conv_fwd_v6r1_nchw";
std::string param = " -std=c++17 ";
std::string network_config;
param += get_definition_string_from_types<TInWei, TAcc, TOut>() +
" -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) +
" -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) +
get_definition_string_from_tunable(tunable);
network_config = get_network_config_string_from_types<TInWei, TAcc, TOut>() + "_" +
std::to_string(hasDoubleTailKBlockLoop) + "_" +
get_network_config_string_from_tunable(tunable);
std::vector<float> kernel1_times;
std::vector<float> kernel2_times;
for(index_t i = 0; i < nrepeat; ++i)
{
KernelTimer timer1, timer2;
std::string kernel_name;
kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare";
auto network_config_1 = network_config + "_1";
timer1.Start();
handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)(
static_cast<index_t>(in_n_c_hi_wi_lengths[I0]),
static_cast<index_t>(in_n_c_hi_wi_lengths[I1]),
static_cast<index_t>(in_n_c_hi_wi_lengths[I2]),
static_cast<index_t>(in_n_c_hi_wi_lengths[I3]),
static_cast<index_t>(wei_k_c_y_x_lengths[I0]),
static_cast<index_t>(wei_k_c_y_x_lengths[I2]),
static_cast<index_t>(wei_k_c_y_x_lengths[I3]),
conv_strides[I0],
conv_strides[I1],
conv_dilations[I0],
conv_dilations[I1],
in_left_pads[I0],
in_left_pads[I1],
in_right_pads[I0],
in_right_pads[I1],
a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf,
b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf,
c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf,
c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf);
timer2.End();
kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw";
auto network_config_2 = network_config + "_2";
timer2.Start();
handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)(
reinterpret_cast<const TInWei*>(wei_k_c_y_x_dev_buf.GetDeviceBuffer()),
reinterpret_cast<const TInWei*>(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()),
reinterpret_cast<TOut*>(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()),
(const void*)(a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf),
(const void*)(b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf),
(const void*)(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf),
(const void*)(c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf));
timer2.End();
kernel1_times.push_back(timer1.GetElapsedTime());
kernel2_times.push_back(timer2.GetElapsedTime());
}
{
auto ave_time1 = Driver::get_effective_average(kernel1_times);
auto ave_time2 = Driver::get_effective_average(kernel2_times);
const auto N = in_n_c_hi_wi_lengths[I0];
const auto C = in_n_c_hi_wi_lengths[I1];
const auto K = out_n_k_ho_wo_lengths[I1];
const auto Ho = out_n_k_ho_wo_lengths[I2];
const auto Wo = out_n_k_ho_wo_lengths[I3];
const auto Y = wei_k_c_y_x_lengths[I2];
const auto X = wei_k_c_y_x_lengths[I3];
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
(std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2);
std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", "
<< ave_time2 << "), " << perf << " TFlop/s" << std::endl;
};
// copy result back to host
out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data());
}

View File

@@ -0,0 +1,19 @@
include_directories(BEFORE
include
)
set(HOST_TENSOR_SOURCE
src/host_tensor.cpp;
src/device.cpp;
)
## the library target
add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE})
target_link_libraries(host_tensor PRIVATE hip::device)
target_link_libraries(host_tensor INTERFACE hip::host)
target_compile_features(host_tensor PUBLIC)
set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS host_tensor LIBRARY DESTINATION lib)

View File

@@ -2,7 +2,8 @@
#define DEVICE_HPP
#include <memory>
#include "config.hpp"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
struct DeviceMem
{
@@ -30,7 +31,6 @@ struct KernelTimer
std::unique_ptr<KernelTimerImpl> impl;
};
#if CK_DEVICE_BACKEND_AMD
using device_stream_t = hipStream_t;
template <typename... Args, typename F>
@@ -83,44 +83,4 @@ float launch_and_time_kernel(F kernel,
return timer.GetElapsedTime() / nrepeat;
}
#elif CK_DEVICE_BACKEND_NVIDIA
using device_stream_t = cudaStream_t;
template <typename... Args, typename F>
void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
cudaStream_t stream_id,
Args... args)
{
const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...};
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
}
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
cudaStream_t stream_id,
Args... args)
{
KernelTimer timer;
const void* f = reinterpret_cast<const void*>(kernel);
void* p_args[] = {&args...};
timer.Start();
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
timer.End();
return timer.GetElapsedTime();
}
#endif
#endif

View File

@@ -1,107 +1,59 @@
#include "config.hpp"
#include "device.hpp"
DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
#if CK_DEVICE_BACKEND_AMD
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
#elif CK_DEVICE_BACKEND_NVIDIA
cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize);
#endif
}
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p)
{
#if CK_DEVICE_BACKEND_AMD
hipGetErrorString(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
#elif CK_DEVICE_BACKEND_NVIDIA
cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice);
#endif
}
void DeviceMem::FromDevice(void* p)
{
#if CK_DEVICE_BACKEND_AMD
hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
#elif CK_DEVICE_BACKEND_NVIDIA
cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost);
#endif
}
DeviceMem::~DeviceMem()
{
#if CK_DEVICE_BACKEND_AMD
hipGetErrorString(hipFree(mpDeviceBuf));
#elif CK_DEVICE_BACKEND_NVIDIA
cudaFree(mpDeviceBuf);
#endif
}
DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); }
struct KernelTimerImpl
{
KernelTimerImpl()
{
#if CK_DEVICE_BACKEND_AMD
hipEventCreate(&mStart);
hipEventCreate(&mEnd);
#elif CK_DEVICE_BACKEND_NVIDIA
cudaEventCreate(&mStart);
cudaEventCreate(&mEnd);
#endif
}
~KernelTimerImpl()
{
#if CK_DEVICE_BACKEND_AMD
hipEventDestroy(mStart);
hipEventDestroy(mEnd);
#elif CK_DEVICE_BACKEND_NVIDIA
cudaEventDestroy(mStart);
cudaEventDestroy(mEnd);
#endif
}
void Start()
{
#if CK_DEVICE_BACKEND_AMD
hipDeviceSynchronize();
hipEventRecord(mStart, 0);
#elif CK_DEVICE_BACKEND_NVIDIA
cudaDeviceSynchronize();
cudaEventRecord(mStart, 0);
#endif
}
void End()
{
#if CK_DEVICE_BACKEND_AMD
hipEventRecord(mEnd, 0);
hipEventSynchronize(mEnd);
#elif CK_DEVICE_BACKEND_NVIDIA
cudaEventRecord(mEnd, 0);
cudaEventSynchronize(mEnd);
#endif
}
float GetElapsedTime() const
{
float time;
#if CK_DEVICE_BACKEND_AMD
hipEventElapsedTime(&time, mStart, mEnd);
#elif CK_DEVICE_BACKEND_NVIDIA
cudaEventElapsedTime(&time, mStart, mEnd);
#endif
return time;
}
#if CK_DEVICE_BACKEND_AMD
hipEvent_t mStart, mEnd;
#elif CK_DEVICE_BACKEND_NVIDIA
cudaEvent_t mStart, mEnd;
#endif
};
KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {}

View File

@@ -1,4 +1,3 @@
set(CMAKE_CXX_COMPILER /opt/rocm/llvm/bin/clang++)
## for online-compiling of HIP kernels
@@ -17,6 +16,7 @@ if(OLC_HIP_COMPILER MATCHES ".*clang\\+\\+$")
${CMAKE_INSTALL_PREFIX}/llvm
)
endif()
if(OLC_OFFLOADBUNDLER_BIN)
message(STATUS "clang-offload-bundler found: ${OLC_OFFLOADBUNDLER_BIN}")
set(OLC_OFFLOADBUNDLER_BIN "${OLC_OFFLOADBUNDLER_BIN}")
@@ -67,92 +67,58 @@ else()
set(OLC_DEBUG 0)
endif()
configure_file("${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/config.h.in" "${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/config.h")
configure_file("${PROJECT_SOURCE_DIR}/host/online_compilation/include/config.h.in" "${PROJECT_BINARY_DIR}/host/online_compilation/include/config.h")
include_directories(BEFORE
${PROJECT_BINARY_DIR}/host/online_compilation/include
)
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
## HIP_COMPILER_FLAGS will be used for on-line compiling of the HIP kernels
add_definitions("-DHIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm/*.hpp")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description/*.hpp")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_3 "${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation/*.hpp")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_4 "${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/*.hpp")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_5 "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/*.hpp")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_6 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp")
file(GLOB_RECURSE COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/*/*.hpp")
file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp")
set(MCONV_KERNEL_INCLUDES
${COMPOSABLE_KERNEL_INCLUDE_1}
${COMPOSABLE_KERNEL_INCLUDE_2}
${COMPOSABLE_KERNEL_INCLUDE_3}
${COMPOSABLE_KERNEL_INCLUDE_4}
${COMPOSABLE_KERNEL_INCLUDE_5}
${COMPOSABLE_KERNEL_INCLUDE_6}
)
set(MCONV_KERNELS
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp
../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp
)
file(GLOB_RECURSE MCONV_KERNELS "${PROJECT_SOURCE_DIR}/composable_kernel/src/kernel_wrapper/*.cpp")
add_kernels("olCompiling/" "${MCONV_KERNELS}")
add_kernel_includes("olCompiling/" "${MCONV_KERNEL_INCLUDES}")
add_kernels(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNELS}")
add_kernel_includes(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNEL_INCLUDES}")
set(MCONV_SOURCES
src/host_tensor.cpp;
src/device.cpp;
set(ONLINE_COMPILATION_SOURCE
${PROJECT_BINARY_DIR}/kernel.cpp
${PROJECT_BINARY_DIR}/kernel_includes.cpp
)
set(OLC_HIP_UTILITY_HEADERS
olCompiling/include/config.h
olCompiling/include/logger.hpp
olCompiling/include/stringutils.hpp
olCompiling/include/tmp_dir.hpp
olCompiling/include/write_file.hpp
olCompiling/include/env.hpp
olCompiling/include/manage_ptr.hpp
olCompiling/include/md5.hpp
olCompiling/include/simple_hash.hpp
olCompiling/include/exec_utils.hpp
olCompiling/include/hipCheck.hpp
olCompiling/include/target_properties.hpp
olCompiling/include/handle.hpp
olCompiling/include/op_kernel_args.hpp
olCompiling/include/kernel.hpp
olCompiling/include/kernel_build_params.hpp
olCompiling/include/hip_build_utils.hpp
olCompiling/include/hipoc_program.hpp
olCompiling/include/hipoc_program_impl.hpp
olCompiling/include/hipoc_kernel.hpp
olCompiling/include/kernel_cache.hpp
olCompiling/include/binary_cache.hpp
)
include_directories(BEFORE
${PROJECT_BINARY_DIR}/host/online_compilation/include
include
)
set(OLC_HIP_UTILITY_CPPS
olCompiling/hip_utility/logger.cpp
olCompiling/hip_utility/tmp_dir.cpp
olCompiling/hip_utility/md5.cpp
olCompiling/hip_utility/exec_utils.cpp
olCompiling/hip_utility/target_properties.cpp
olCompiling/hip_utility/handlehip.cpp
olCompiling/hip_utility/kernel_build_params.cpp
olCompiling/hip_utility/hip_build_utils.cpp
olCompiling/hip_utility/hipoc_program.cpp
olCompiling/hip_utility/hipoc_kernel.cpp
olCompiling/hip_utility/kernel_cache.cpp
olCompiling/hip_utility/binary_cache.cpp
hip_utility/logger.cpp
hip_utility/tmp_dir.cpp
hip_utility/md5.cpp
hip_utility/exec_utils.cpp
hip_utility/target_properties.cpp
hip_utility/handlehip.cpp
hip_utility/kernel_build_params.cpp
hip_utility/hip_build_utils.cpp
hip_utility/hipoc_program.cpp
hip_utility/hipoc_kernel.cpp
hip_utility/kernel_cache.cpp
hip_utility/binary_cache.cpp
)
list(APPEND OLC_SOURCES ${OLC_HIP_UTILITY_CPPS} ${OLC_HIP_UTILITY_HEADERS})
list(INSERT MCONV_SOURCES 0
${PROJECT_BINARY_DIR}/kernel.cpp
${PROJECT_BINARY_DIR}/kernel_includes.cpp
)
## addkernels provide the tool to create inlined kernels in one header
add_subdirectory(olCompiling/addkernels)
add_subdirectory(addkernels)
function(inline_kernels_src KERNELS KERNEL_INCLUDES)
set(KERNEL_SRC_HPP_FILENAME batch_all.cpp.hpp)
@@ -166,7 +132,7 @@ function(inline_kernels_src KERNELS KERNEL_INCLUDES)
COMMAND $<TARGET_FILE:addkernels> -target ${KERNEL_SRC_HPP_PATH} -extern -source ${KERNELS}
COMMENT "Inlining All kernels"
)
configure_file(olCompiling/kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH})
configure_file(kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH})
list(APPEND OLC_SOURCES ${KERNEL_SRC_CPP_PATH} ${KERNEL_SRC_HPP_PATH})
set(OLC_SOURCES ${OLC_SOURCES} PARENT_SCOPE)
@@ -174,7 +140,7 @@ endfunction()
inline_kernels_src("${MCONV_KERNELS}" "${MCONV_KERNEL_INCLUDES}")
list(APPEND MCONV_SOURCES ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h)
list(APPEND ONLINE_COMPILATION_SOURCE ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h)
add_custom_command(
OUTPUT ${PROJECT_BINARY_DIR}/olc_kernel_includes.h
@@ -185,19 +151,17 @@ add_custom_command(
)
## the library target
add_library(modConv SHARED ${MCONV_SOURCES})
add_library(online_compilation SHARED ${ONLINE_COMPILATION_SOURCE})
target_include_directories(modConv PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/)
target_include_directories(modConv PRIVATE ${PROJECT_BINARY_DIR})
target_include_directories(modConv PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/)
target_include_directories(online_compilation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/online_compilation/include/)
target_include_directories(online_compilation PRIVATE ${PROJECT_BINARY_DIR})
target_include_directories(online_compilation PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/)
target_link_libraries(modConv PRIVATE hip::device)
target_link_libraries(modConv INTERFACE hip::host)
target_link_libraries(modConv PRIVATE Boost::filesystem)
target_link_libraries(online_compilation PRIVATE hip::device)
target_link_libraries(online_compilation INTERFACE hip::host)
target_link_libraries(online_compilation PRIVATE Boost::filesystem)
target_compile_options(modConv PRIVATE -mfma)
target_compile_features(online_compilation PUBLIC)
set_target_properties(online_compilation PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_features(modConv PUBLIC)
set_target_properties(modConv PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS modConv LIBRARY DESTINATION lib)
install(TARGETS online_compilation LIBRARY DESTINATION lib)

Some files were not shown because too many files have changed in this diff Show More